Skip to content

Commit 871d3fb

Browse files
authored
Fix a bug in ReluClip fusion (#9764)
1 parent b409cbe commit 871d3fb

File tree

3 files changed

+33
-1
lines changed

3 files changed

+33
-1
lines changed

onnxruntime/core/optimizer/relu_clip_fusion.cc

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,17 @@ Status FuseReluClip::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_eff
9696
mutable_next_node->ClearAttribute("min");
9797
mutable_next_node->AddAttribute("min", 0.f);
9898
} else {
99+
// Add the initialized tensor to the graph
99100
graph.AddInitializedTensor(replacement_min);
101+
102+
// Create a corresponding NodeArg for the initialized tensor
103+
ONNX_NAMESPACE::TypeProto t;
104+
t.mutable_tensor_type()->set_elem_type(replacement_min.data_type());
105+
NodeArg* replacement_min_nodearg = &graph.GetOrCreateNodeArg(replacement_min.name(), &t);
106+
107+
// Replace the input def at the appropriate index of the Clip node
100108
auto& mutable_input_defs = mutable_next_node->MutableInputDefs();
101-
NodeArg* replacement_min_nodearg = graph.GetNodeArg(replacement_min.name());
109+
102110
if (mutable_input_defs.size() == 1) { // Clip node only has the required 'input' so add optional 'min' input
103111
mutable_input_defs.push_back(replacement_min_nodearg);
104112
mutable_next_node->MutableInputArgsCount().push_back(1);

onnxruntime/test/optimizer/graph_transform_test.cc

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1941,6 +1941,30 @@ TEST_F(GraphTransformationTests, ReluClip11Fusion) {
19411941
}
19421942
}
19431943

1944+
TEST_F(GraphTransformationTests, ReluClip11FusionGHIssue9753) {
1945+
auto model_uri = MODEL_FOLDER "fusion/relu_clip_fusion_gh_issue_9753.onnx";
1946+
std::shared_ptr<Model> model;
1947+
ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_));
1948+
Graph& graph = model->MainGraph();
1949+
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
1950+
1951+
// The model contains one Relu and one Clip
1952+
ASSERT_TRUE(op_to_count["Relu"] == 1);
1953+
ASSERT_TRUE(op_to_count["Clip"] == 1);
1954+
1955+
auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformer1");
1956+
ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique<FuseReluClip>()));
1957+
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
1958+
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1));
1959+
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
1960+
1961+
op_to_count = CountOpsInGraph(graph);
1962+
1963+
// After fusion, the model only contains Clip.
1964+
ASSERT_TRUE(op_to_count["Relu"] == 0);
1965+
ASSERT_TRUE(op_to_count["Clip"] == 1);
1966+
}
1967+
19441968
// Test Reshape Fusion with 2 constant initializers for Concat inputs.
19451969
TEST_F(GraphTransformationTests, ReshapeFusionTest) {
19461970
auto model_uri = MODEL_FOLDER "fusion/reshape.onnx";
Binary file not shown.

0 commit comments

Comments
 (0)