Commit d81c97f
[Transformations] Make ov::ModelPass transformations execute recursively (openvinotoolkit#23058)

### Details:
Some ov::ModelPass transformations lack recursive execution for subgraphs, leaving them unprocessed.
Add the required recursive call for MultiSubGraphOp operations.

### Tickets:
- CVS-116659

Signed-off-by: Andrii Staikov <[email protected]>

CuriousPanCake authored Mar 12, 2024
1 parent 461a138 commit d81c97f
Showing 19 changed files with 100 additions and 70 deletions.
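
Every hunk below routes subgraph recursion through a single shared helper, ov::op::util::process_subgraph, declared in transformations/utils/utils.hpp. As a reading aid, here is a minimal sketch of what that helper plausibly does, inferred from its declaration and the call sites in this commit (an assumption, not the verbatim implementation):

```cpp
// Sketch: run the given ModelPass on every body of a multi-subgraph
// operation (If, Loop, TensorIterator); report whether anything changed.
bool process_subgraph(ov::pass::ModelPass& model_pass, const std::shared_ptr<ov::Node>& node) {
    bool changed = false;
    if (auto multi_subgraph_op = std::dynamic_pointer_cast<ov::op::util::MultiSubGraphOp>(node)) {
        for (const auto& sub_graph : multi_subgraph_op->get_functions()) {
            if (sub_graph)
                changed = model_pass.run_on_model(sub_graph) || changed;
        }
    }
    return changed;
}
```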
@@ -44,6 +44,9 @@ class ov::pass::low_precision::PropagateSharedValue : public ov::pass::ModelPass
         std::vector<std::shared_ptr<ov::Node>> nodes(f->get_ordered_ops());
         for (auto it = nodes.begin(); it != nodes.end(); it++) {
             const std::shared_ptr<Node> node = *it;
+
+            ov::op::util::process_subgraph(*this, node);
+
             if (ov::is_type<opset1::FakeQuantize>(node)) {
                 assert(node->get_output_size() == 1ul);
                 auto& outputRtInfo = node->output(0).get_rt_info();
3 changes: 3 additions & 0 deletions src/common/snippets/src/pass/propagate_precision.cpp
@@ -8,6 +8,7 @@
 #include "snippets/itt.hpp"
 #include "snippets/utils.hpp"
 #include "openvino/core/rt_info.hpp"
+#include "transformations/utils/utils.hpp"
 
 #include <assert.h>
 #include <memory>
@@ -29,6 +30,8 @@ bool ov::snippets::pass::PropagatePrecision::run_on_model(const std::shared_ptr<
 
     bool was_updated = false;
     for (const auto& op : f->get_ordered_ops()) {
+        ov::op::util::process_subgraph(*this, op);
+
         auto type_info = op->get_type_info();
         auto exec = target_machine->get_supported_precisions(type_info);
         const auto& supported_precisions = exec(op);
@@ -276,6 +276,8 @@ TRANSFORMATIONS_API bool is_constant_and_all_values_equal_int(const Output<Node>
 
 TRANSFORMATIONS_API bool is_on_constant_path(const ov::Output<ov::Node>& output);
 
+TRANSFORMATIONS_API bool process_subgraph(ov::pass::ModelPass& model_pass, const std::shared_ptr<Node>& node);
+
 template <typename T>
 ov::pass::pattern::op::ValuePredicate constant_predicate(std::function<bool(const std::vector<T>&)> predicate) {
     return pass::pattern::op::as_value_predicate([=](std::shared_ptr<Node> n) -> bool {
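
Given the declaration above, any ModelPass can opt into subgraph recursion with one call at the top of its node loop. A minimal hypothetical pass (the class name and body are invented for illustration):

```cpp
#include "openvino/pass/pass.hpp"
#include "transformations/utils/utils.hpp"

class MySubgraphAwarePass : public ov::pass::ModelPass {
public:
    OPENVINO_RTTI("MySubgraphAwarePass", "0");
    bool run_on_model(const std::shared_ptr<ov::Model>& f) override {
        bool changed = false;
        for (const auto& node : f->get_ordered_ops()) {
            // Recurse into If/Loop/TensorIterator bodies first.
            changed = ov::op::util::process_subgraph(*this, node) || changed;
            // ... per-node rewriting logic goes here ...
        }
        return changed;
    }
};
```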
@@ -6,11 +6,14 @@
 
 #include "openvino/cc/pass/itt.hpp"
 #include "transformations/rt_info/fused_names_attribute.hpp"
+#include "transformations/utils/utils.hpp"
 
 bool ov::pass::FusedNamesCleanup::run_on_model(const std::shared_ptr<ov::Model>& f) {
     RUN_ON_FUNCTION_SCOPE(FusedNamesCleanup);
 
     for (auto& node : f->get_ordered_ops()) {
+        ov::op::util::process_subgraph(*this, node);
+
         RTMap& rt_info = node->get_rt_info();
         auto it = rt_info.find(ov::FusedNames::get_type_info_static());
         if (it != rt_info.end()) {
@@ -19,6 +19,7 @@
 #include "openvino/pass/manager.hpp"
 #include "transformations/common_optimizations/shared_ops_optimization.hpp"
 #include "transformations/op_conversions/convert_slice_to_strided_slice.hpp"
+#include "transformations/utils/utils.hpp"
 
 using namespace ov;
 
@@ -27,11 +28,8 @@ bool ov::pass::UselessSliceEraser::run_on_model(const std::shared_ptr<ov::Model>
     bool rewritten = false;
     for (auto& node : f->get_ordered_ops()) {
         // Recursively apply transformation for sub-graph based operations
-        if (auto sub_graph_node = std::dynamic_pointer_cast<op::util::SubGraphOp>(node)) {
-            if (auto sub_graph = sub_graph_node->get_function()) {
-                rewritten |= run_on_model(sub_graph);
-            }
-        }
+        rewritten = ov::op::util::process_subgraph(*this, node) || rewritten;
+
         bool is_slice = ov::is_type<ov::op::v1::StridedSlice>(node) || ov::is_type<ov::op::v8::Slice>(node);
         if (!is_slice || node->get_output_partial_shape(0).is_dynamic() ||
             node->get_input_partial_shape(0).is_dynamic())
@@ -45,7 +43,7 @@ bool ov::pass::UselessSliceEraser::run_on_model(const std::shared_ptr<ov::Model>
         if (!std::any_of(strides.begin(), strides.end(), [](int64_t strd) {
                 return strd < 0;
             })) {
-            rewritten |= replace_output_update_name(node->output(0), node->input_value(0));
+            rewritten = replace_output_update_name(node->output(0), node->input_value(0)) || rewritten;
         }
     }
 }
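
A note on the accumulation style adopted throughout this commit: the freshly computed result is placed on the left of ||, so short-circuit evaluation can never skip the call once the flag is already true (the old |= form also always evaluated its right-hand side; the rewrite presumably just trades the bitwise operator on bool for explicit boolean logic). An illustration of the hazard the chosen operand order avoids:

```cpp
bool rewritten = true;                             // some earlier node was already rewritten
rewritten = rewritten || run_on_model(sub_graph);  // WRONG: || short-circuits, the pass never runs
rewritten = run_on_model(sub_graph) || rewritten;  // OK: the pass always runs, the flag still accumulates
```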
@@ -102,11 +100,8 @@ bool ov::pass::GroupedStridedSliceOptimizer::run_on_model(const std::shared_ptr<
     std::map<ov::Output<Node>, std::vector<planned_slice>> source_to_ss_with_plan;
     for (const auto& node : f->get_ordered_ops()) {
         // Recursively apply transformation for sub-graph based operations
-        if (auto sub_graph_node = std::dynamic_pointer_cast<op::util::SubGraphOp>(node)) {
-            if (auto sub_graph = sub_graph_node->get_function()) {
-                graph_rewritten |= run_on_model(sub_graph);
-            }
-        }
+        graph_rewritten = ov::op::util::process_subgraph(*this, node) || graph_rewritten;
+
         if (auto ss = std::dynamic_pointer_cast<ov::op::v1::StridedSlice>(node)) {
             auto slice_plan = get_slice_plan(ss);
             if (slice_plan == op::util::SlicePlan())
@@ -291,12 +286,8 @@ bool ov::pass::GroupedSliceToVSplitOptimization::run_on_model(const std::shared_
     std::vector<OutputWithAxis> ordered_outputs;
     for (const auto& node : model->get_ordered_ops()) {
         // Recursively apply transformation for sub-graph based operations
-        if (auto multi_subgraph_op = std::dynamic_pointer_cast<op::util::MultiSubGraphOp>(node)) {
-            for (const auto& sub_graph : multi_subgraph_op->get_functions()) {
-                if (sub_graph)
-                    graph_rewritten |= run_on_model(sub_graph);
-            }
-        }
+        graph_rewritten = ov::op::util::process_subgraph(*this, node) || graph_rewritten;
+
         if (auto op = ov::as_type_ptr<op::v8::Slice>(node)) {
             SliceAttrs attributes{};
             if (slice_is_suitable_for_optimization(op, attributes)) {
@@ -365,8 +356,9 @@ bool ov::pass::GroupedSliceToVSplitOptimization::run_on_model(const std::shared_
 
         auto i = 0;
         for (auto& slice_with_attrs : attributes) {
-            graph_rewritten |=
-                ov::replace_output_update_name(slice_with_attrs.slice->output(0), variadic_split->output(i));
+            graph_rewritten =
+                ov::replace_output_update_name(slice_with_attrs.slice->output(0), variadic_split->output(i)) ||
+                graph_rewritten;
             ov::copy_runtime_info(slice_with_attrs.slice, variadic_split);
             ++i;
         }
@@ -21,6 +21,7 @@
 #include "openvino/op/util/binary_elementwise_arithmetic.hpp"
 #include "openvino/op/util/pad_base.hpp"
 #include "openvino/op/util/unary_elementwise_arithmetic.hpp"
+#include "transformations/utils/utils.hpp"
 
 bool ov::pass::ReverseShapeAndTypeInfer::inherit_output_shape(const std::shared_ptr<ov::Node>& node,
                                                               const std::vector<size_t>& input_idxs) {
@@ -70,6 +71,8 @@ bool ov::pass::ReverseShapeAndTypeInfer::run_on_model(const std::shared_ptr<ov::
     auto ops = f->get_ordered_ops();
     for (auto it = ops.rbegin(); it != ops.rend(); ++it) {
         const auto& op = *it;
+        is_changed = ov::op::util::process_subgraph(*this, op) || is_changed;
+
         auto output_shape = op->get_output_partial_shape(0);
         auto output_type = op->get_output_element_type(0);
         if (const auto& param = std::dynamic_pointer_cast<ov::op::v0::Parameter>(op)) {
@@ -109,7 +109,7 @@ bool shared_node_optimization(const shared_ptr<Model>& model) {
         if (auto multi_subgraph_op = dynamic_pointer_cast<op::util::MultiSubGraphOp>(op)) {
             for (const auto& sub_graph : multi_subgraph_op->get_functions()) {
                 if (sub_graph)
-                    rewritten |= shared_node_optimization(sub_graph);
+                    rewritten = shared_node_optimization(sub_graph) || rewritten;
             }
         }
         for (auto& output : op->outputs()) {
@@ -136,7 +136,8 @@ bool shared_node_optimization(const shared_ptr<Model>& model) {
                 continue;
             const auto& child_op = shared_nodes[j];
             if (nodes_are_equal(root_op, child_op)) {
-                rewritten |= replace_output_update_name(child_op->output(0), root_op->output(0));
+                rewritten =
+                    replace_output_update_name(child_op->output(0), root_op->output(0)) || rewritten;
                 visited_nodes[j] = true;
             }
         }
@@ -154,7 +155,7 @@ bool shape_of_upgrade(const shared_ptr<Model>& model) {
         if (auto multi_subgraph_op = dynamic_pointer_cast<op::util::MultiSubGraphOp>(op)) {
             for (const auto& sub_graph : multi_subgraph_op->get_functions()) {
                 if (sub_graph)
-                    rewritten |= shape_of_upgrade(sub_graph);
+                    rewritten = shape_of_upgrade(sub_graph) || rewritten;
             }
         } else if (auto v1_shape_of = ov::as_type_ptr<v0::ShapeOf>(op)) {
             auto v3_shape_of = std::make_shared<v3::ShapeOf>(v1_shape_of->input_value(0), element::i64);
@@ -171,6 +172,6 @@ bool pass::SharedOpOptimization::run_on_model(const shared_ptr<Model>& model) {
     RUN_ON_FUNCTION_SCOPE(SharedOpOptimization);
 
     bool rewritten = shape_of_upgrade(model);
-    rewritten |= shared_node_optimization(model);
+    rewritten = shared_node_optimization(model) || rewritten;
     return rewritten;
 }
@@ -23,6 +23,8 @@
 bool ov::pass::UnrollTensorIterator::run_on_model(const std::shared_ptr<ov::Model>& f) {
     RUN_ON_FUNCTION_SCOPE(UnrollTensorIterator);
     for (const auto& op : f->get_ops()) {
+        ov::op::util::process_subgraph(*this, op);
+
         auto sub_graph_op = std::dynamic_pointer_cast<op::util::SubGraphOp>(op);
         if (!sub_graph_op || transformation_callback(sub_graph_op)) {
             continue;
@@ -231,18 +231,18 @@ bool convert_function_precision(const std::shared_ptr<Model>& f,
     for (auto& node : ops) {
         if (skip_precision_sensitive && fp16_compression_is_disabled(node) && has_fp16_compression)
             continue;
-        is_changed |= convert_node_input_precision(node, precisions, type_to_extend);
+        is_changed = convert_node_input_precision(node, precisions, type_to_extend) || is_changed;
     }
 
     for (const auto& param : f->get_parameters()) {
         if (skip_precision_sensitive && fp16_compression_is_disabled(param) && has_fp16_compression)
             continue;
-        is_changed |= fuse_type_to_parameter(param, precisions, convert_input_output_precision);
+        is_changed = fuse_type_to_parameter(param, precisions, convert_input_output_precision) || is_changed;
     }
 
     if (convert_input_output_precision || store_original_precision_as_rt_attribute) {
         for (const auto& variable : f->get_variables()) {
-            is_changed |= fuse_type_to_variable(variable, precisions);
+            is_changed = fuse_type_to_variable(variable, precisions) || is_changed;
         }
     }
 
@@ -272,17 +272,18 @@ bool convert_function_precision(const std::shared_ptr<Model>& f,
         if (auto sub_graph_node = std::dynamic_pointer_cast<op::util::MultiSubGraphOp>(node)) {
             size_t sub_graphs_num = sub_graph_node->get_internal_subgraphs_size();
             for (size_t sub_graph_ind = 0; sub_graph_ind < sub_graphs_num; ++sub_graph_ind) {
-                is_changed |= convert_function_precision(sub_graph_node->get_function(static_cast<int>(sub_graph_ind)),
-                                                         type_to_fuse,
-                                                         type_to_extend,
-                                                         precisions,
-                                                         const_to_internal_output,
-                                                         has_fp16_compression,
-                                                         skip_precision_sensitive,
-                                                         is_changed || is_output_precision_changed,
-                                                         true,
-                                                         true,
-                                                         store_original_precision_as_rt_attribute);
+                is_changed = convert_function_precision(sub_graph_node->get_function(static_cast<int>(sub_graph_ind)),
+                                                        type_to_fuse,
+                                                        type_to_extend,
+                                                        precisions,
+                                                        const_to_internal_output,
+                                                        has_fp16_compression,
+                                                        skip_precision_sensitive,
+                                                        is_changed || is_output_precision_changed,
+                                                        true,
+                                                        true,
+                                                        store_original_precision_as_rt_attribute) ||
+                             is_changed;
             }
         }
         // if convert_input_output_precision flag is set, we don't need to preserve the original precision
@@ -293,16 +294,17 @@ bool convert_function_precision(const std::shared_ptr<Model>& f,
             node->revalidate_and_infer_types();
             continue;
         }
-        is_output_precision_changed |= convert_node_output_precision(node,
-                                                                     precisions,
-                                                                     type_to_fuse,
-                                                                     const_to_internal_output,
-                                                                     is_changed || is_output_precision_changed);
+        is_output_precision_changed = convert_node_output_precision(node,
+                                                                    precisions,
+                                                                    type_to_fuse,
+                                                                    const_to_internal_output,
+                                                                    is_changed || is_output_precision_changed) ||
+                                      is_output_precision_changed;
     }
 
     if (is_output_precision_changed) {
         ops = f->get_ordered_ops();
-        is_changed |= is_output_precision_changed;
+        is_changed = is_output_precision_changed || is_changed;
     }
 
     if (!is_subgraph) {
@@ -88,8 +88,8 @@ bool ov::pass::AlignMixedFP32FP16Types::run_on_model(const std::shared_ptr<ov::M
         if (!fp16_compression_is_disabled(node))
             continue;
 
-        is_changed |= insert_converts_before_if_needed(node);
-        is_changed |= insert_converts_after_if_needed(node);
+        is_changed = insert_converts_before_if_needed(node) || is_changed;
+        is_changed = insert_converts_after_if_needed(node) || is_changed;
     }
 
     return is_changed;
@@ -11,17 +11,15 @@
 #include "openvino/op/util/sub_graph_base.hpp"
 #include "transformations/rt_info/fused_names_attribute.hpp"
 #include "transformations/rt_info/primitives_priority_attribute.hpp"
+#include "transformations/utils/utils.hpp"
 
 bool ov::pass::InitNodeInfo::run_on_model(const std::shared_ptr<ov::Model>& f) {
     RUN_ON_FUNCTION_SCOPE(InitNodeInfo);
 
     for (auto& node : f->get_ops()) {
         // Recursively apply transformation for sub-graph based operations
-        if (auto sub_graph_node = std::dynamic_pointer_cast<op::util::SubGraphOp>(node)) {
-            if (auto sub_graph = sub_graph_node->get_function()) {
-                run_on_model(sub_graph);
-            }
-        }
+        ov::op::util::process_subgraph(*this, node);
+
         auto& rtInfo = node->get_rt_info();
         rtInfo.emplace(FusedNames::get_type_info_static(), FusedNames{node->get_friendly_name()});
     }
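
With InitNodeInfo recursing via process_subgraph, FusedNames and similar rt_info attributes now also land on nodes inside If/Loop/TensorIterator bodies. A small usage sketch of the standard pass-manager flow (header paths assumed; model is any previously built ov::Model containing such an op):

```cpp
#include "openvino/pass/manager.hpp"
#include "transformations/init_node_info.hpp"

void init_info(const std::shared_ptr<ov::Model>& model) {
    ov::pass::Manager manager;
    manager.register_pass<ov::pass::InitNodeInfo>();
    manager.run_passes(model);  // rt_info is now initialized recursively, subgraph bodies included
}
```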
@@ -135,11 +135,11 @@ bool relax_batch_for_initial_states_of_lstm_in_ti(const shared_ptr<ov::op::v0::T
         return rewritten;
     if (auto init_hidden_state = dynamic_pointer_cast<ov::op::v0::Parameter>(lstm_cell->get_input_node_shared_ptr(1))) {
         auto outer_init_hidden_state_input = get_outer_input_of_ti_by_parameter(init_hidden_state, ti);
-        rewritten |= broadcast_state_by_batch(outer_init_hidden_state_input, batch_delivering_node);
+        rewritten = broadcast_state_by_batch(outer_init_hidden_state_input, batch_delivering_node) || rewritten;
     }
     if (auto init_cell_state = dynamic_pointer_cast<ov::op::v0::Parameter>(lstm_cell->get_input_node_shared_ptr(2))) {
         auto outer_init_cell_state_input = get_outer_input_of_ti_by_parameter(init_cell_state, ti);
-        rewritten |= broadcast_state_by_batch(outer_init_cell_state_input, batch_delivering_node);
+        rewritten = broadcast_state_by_batch(outer_init_cell_state_input, batch_delivering_node) || rewritten;
     }
     return rewritten;
 }
@@ -151,8 +151,8 @@ bool relax_batch_for_initial_states_of_lstm(const shared_ptr<ov::op::v4::LSTMCel
         make_shared<ov::op::v8::Gather>(batched_shape,
                                         ov::op::v0::Constant::create(ov::element::i64, ov::Shape{1}, {0}),
                                         ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {0}));
-    rewritten |= broadcast_state_by_batch(lstm_cell->input(1), batch_delivering_node);
-    rewritten |= broadcast_state_by_batch(lstm_cell->input(2), batch_delivering_node);
+    rewritten = broadcast_state_by_batch(lstm_cell->input(1), batch_delivering_node) || rewritten;
+    rewritten = broadcast_state_by_batch(lstm_cell->input(2), batch_delivering_node) || rewritten;
     return rewritten;
 }
 
@@ -163,13 +163,11 @@ bool ov::pass::LSTMStatesBroadcast::run_on_model(const shared_ptr<ov::Model>& f)
     bool rewritten = false;
     for (auto& node : f->get_ordered_ops()) {
         // Recursively apply transformation for sub-graph based operations
-        if (const auto& sub_graph_node = dynamic_pointer_cast<ov::op::util::SubGraphOp>(node))
-            if (const auto& sub_graph = sub_graph_node->get_function())
-                rewritten |= run_on_model(sub_graph);
+        rewritten = ov::op::util::process_subgraph(*this, node) || rewritten;
 
         // Case without TI (LSTMCell and Constant are in the same ov::Model)
         if (const auto& lstm_cell = dynamic_pointer_cast<ov::op::v4::LSTMCell>(node))
-            rewritten |= relax_batch_for_initial_states_of_lstm(lstm_cell);
+            rewritten = relax_batch_for_initial_states_of_lstm(lstm_cell) || rewritten;
 
         // Case with TI (LSTMCell and Constant are in different ov::Model objects)
         if (auto ti = dynamic_pointer_cast<ov::op::v0::TensorIterator>(node)) {
@@ -178,7 +176,7 @@ bool ov::pass::LSTMStatesBroadcast::run_on_model(const shared_ptr<ov::Model>& f)
             continue;
         for (const auto& body_node : body->get_ordered_ops())
            if (const auto& lstm_cell = dynamic_pointer_cast<ov::op::v4::LSTMCell>(body_node))
-                rewritten |= relax_batch_for_initial_states_of_lstm_in_ti(ti, lstm_cell);
+                rewritten = relax_batch_for_initial_states_of_lstm_in_ti(ti, lstm_cell) || rewritten;
         }
     }
     return rewritten;
@@ -17,6 +17,7 @@
 #include "openvino/op/squeeze.hpp"
 #include "openvino/op/util/multi_subgraph_base.hpp"
 #include "openvino/op/util/symbolic_info.hpp"
+#include "transformations/utils/utils.hpp"
 
 namespace {
 void update_label(const ov::EqTable& table, ov::label_t& label) {
@@ -250,10 +251,7 @@ bool ov::pass::OptimizeLabelsUsedAsValues::run_on_model(const std::shared_ptr<ov
             continue;
 
         // LTS maps aren't shared with sub-graphs because inner graph can not access outer graph for label sources
-        if (auto multi_subgraph_op = std::dynamic_pointer_cast<ov::op::util::MultiSubGraphOp>(op))
-            for (const auto& sub_graph : multi_subgraph_op->get_functions())
-                if (sub_graph)
-                    run_on_model(sub_graph);
+        ov::op::util::process_subgraph(*this, op);
 
         for (auto& output : op->outputs()) {
             optimize_value_usage(output, label_shape_source, label_value_source);