Skip to content

Commit

Permalink
transform ambiguous convert
Browse files Browse the repository at this point in the history
  • Loading branch information
chenhu-wang committed Dec 6, 2024
1 parent ea72f30 commit 158d71f
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
2 changes: 2 additions & 0 deletions src/common/snippets/src/op/subgraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "snippets/pass/align_element_types.hpp"
#include "snippets/pass/reduce_to_snippets_reduce.hpp"
#include "snippets/pass/gn_decomposition.hpp"
#include "snippets/pass/transform_convert.hpp"

#include "snippets/runtime_configurator.hpp"
#include "snippets/utils/utils.hpp"
Expand Down Expand Up @@ -428,6 +429,7 @@ void Subgraph::data_flow_transformations(const BlockedShapeVector& blocked_input
manager.register_pass<snippets::pass::ConvertConstantsToScalars>();

manager.register_positioned_passes(backend_passes);
manager.register_pass<snippets::pass::TransformConvertToConvertTruncation>();
manager.run_passes(body_ptr());
}

Expand Down
5 changes: 2 additions & 3 deletions src/common/snippets/src/pass/transform_convert.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,14 @@

ov::snippets::pass::TransformConvertToConvertTruncation::TransformConvertToConvertTruncation() {
MATCHER_SCOPE(TransformConvertToConvertTruncation);
auto convert = std::make_shared<ov::pass::pattern::op::Label>(ov::pass::pattern::any_input(),
auto convert_pattern = std::make_shared<ov::pass::pattern::op::Label>(ov::pass::pattern::any_input(),
[](const std::shared_ptr<const Node> &n) {
return ov::is_type<ov::opset1::Convert>(n) &&
!ov::is_type<op::ConvertTruncation>(n) &&
!ov::is_type<op::ConvertSaturation>(n);
});

register_matcher(std::make_shared<ov::pass::pattern::Matcher>(
ov::pass::pattern::wrap_type<ov::opset1::Convert>(), matcher_name), [](ov::pass::pattern::Matcher &m) {
register_matcher(std::make_shared<ov::pass::pattern::Matcher>(convert_pattern, matcher_name), [](ov::pass::pattern::Matcher &m) {
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::op::TransformConvertToConvertTruncation")
const auto root = m.get_match_root();
const auto convert = ov::as_type_ptr<ov::opset1::Convert>(root);
Expand Down

0 comments on commit 158d71f

Please sign in to comment.