diff --git a/src/operator/subgraph/dnnl/dnnl_remove_casts_property.h b/src/operator/subgraph/dnnl/dnnl_remove_casts_property.h index d7967189493c..d9111878d429 100644 --- a/src/operator/subgraph/dnnl/dnnl_remove_casts_property.h +++ b/src/operator/subgraph/dnnl/dnnl_remove_casts_property.h @@ -95,7 +95,7 @@ class SgDNNLRemoveCastsSelector : public SubgraphSelectorV2 { } void Reset() override { - status_ = kFail; + status_ = kExpand; castDtype = -1; } }; @@ -105,7 +105,7 @@ class SgDNNLRemoveCastsProperty : public SubgraphProperty { SgDNNLRemoveCastsProperty() {} static SubgraphPropertyPtr Create() { - static const std::string& name = "Remove casts optimization pass"; + static const std::string& name = "Remove Casts optimization pass"; auto property = std::make_shared(); property->SetAttr("property_name", name); property->SetAttr("inference_only", true); @@ -137,6 +137,15 @@ class SgDNNLRemoveCastsProperty : public SubgraphProperty { auto selector = std::make_shared(); return selector; } + + void ConnectSubgraphOutputs(const nnvm::ObjectPtr subgraph_node, + std::vector* output_entries) const override { + // Connect all extern output entries to output[0] + for (size_t i = 0; i < output_entries->size(); ++i) { + auto entry_ptr = output_entries->at(i); + *entry_ptr = nnvm::NodeEntry{subgraph_node, entry_ptr->index, 0}; + } + } }; } // namespace op