From 5a87a0cfdb76b8c8db71b941f7555a1e1d261865 Mon Sep 17 00:00:00 2001 From: KexinFeng Date: Fri, 8 Jul 2022 15:37:54 -0700 Subject: [PATCH] ci test --- .../subgraph/dnnl/dnnl_remove_casts_property.h | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/operator/subgraph/dnnl/dnnl_remove_casts_property.h b/src/operator/subgraph/dnnl/dnnl_remove_casts_property.h index d7967189493c..d9111878d429 100644 --- a/src/operator/subgraph/dnnl/dnnl_remove_casts_property.h +++ b/src/operator/subgraph/dnnl/dnnl_remove_casts_property.h @@ -95,7 +95,7 @@ class SgDNNLRemoveCastsSelector : public SubgraphSelectorV2 { } void Reset() override { - status_ = kFail; + status_ = kExpand; castDtype = -1; } }; @@ -105,7 +105,7 @@ class SgDNNLRemoveCastsProperty : public SubgraphProperty { SgDNNLRemoveCastsProperty() {} static SubgraphPropertyPtr Create() { - static const std::string& name = "Remove casts optimization pass"; + static const std::string& name = "Remove Casts optimization pass"; auto property = std::make_shared(); property->SetAttr("property_name", name); property->SetAttr("inference_only", true); @@ -137,6 +137,15 @@ class SgDNNLRemoveCastsProperty : public SubgraphProperty { auto selector = std::make_shared(); return selector; } + + void ConnectSubgraphOutputs(const nnvm::ObjectPtr subgraph_node, + std::vector* output_entries) const override { + // Connect all extern output entries to output[0] + for (size_t i = 0; i < output_entries->size(); ++i) { + auto entry_ptr = output_entries->at(i); + *entry_ptr = nnvm::NodeEntry{subgraph_node, entry_ptr->index, 0}; + } + } }; } // namespace op