From 1d5ca89b947004caa78a692ba6b4ae69b0182605 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Tue, 21 Jan 2025 15:54:27 -0800 Subject: [PATCH 01/24] init --- .../core/framework/compute_capability.h | 17 + onnxruntime/core/optimizer/constant_folding.h | 2 + .../constant_folding_dq_node.cc | 359 ++++++++++++++++++ .../constant_folding_dq_node.h | 45 +++ .../shared_library/provider_interfaces.h | 2 + .../core/session/provider_bridge_ort.cc | 30 ++ 6 files changed, 455 insertions(+) create mode 100644 onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.cc create mode 100644 onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.h diff --git a/onnxruntime/core/framework/compute_capability.h b/onnxruntime/core/framework/compute_capability.h index 5f21ba2f013e0..de479ed029304 100644 --- a/onnxruntime/core/framework/compute_capability.h +++ b/onnxruntime/core/framework/compute_capability.h @@ -2,8 +2,11 @@ // Licensed under the MIT License. #pragma once +#include #include "core/common/common.h" #include "core/graph/indexed_sub_graph.h" +#include "core/graph/graph.h" + namespace onnxruntime { // A structure encodes a subgraph and the method to run it. @@ -21,5 +24,19 @@ struct ComputeCapability { ComputeCapability(std::unique_ptr t_sub_graph) : sub_graph(std::move(t_sub_graph)) {} + + // optional function to optimize this ComputeCapability + // this will be called by ORT once the ComputeCapability is assigned to the EP + // Optimization: std::function + std::function optimization_func; + + // optional ComputeCapability instances for sets of nodes within this ComputeCapability that should be optimized. + // when an optimization is applied, ORT will update this ComputeCapability to reflect the changes made. 
+ // IndexedSubGraph.nodes: + // - update based on RemovedNode/AddNode calls + // IndexedSubGraph.MetaDef (if present): + // - inputs and outputs will be unchanged + // - constant_initializers MAY change if we constant fold an initializer during optimization + std::vector> nodes_to_optimize; }; } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/constant_folding.h b/onnxruntime/core/optimizer/constant_folding.h index 14eb2a9c5f06b..4fe2658e89cda 100644 --- a/onnxruntime/core/optimizer/constant_folding.h +++ b/onnxruntime/core/optimizer/constant_folding.h @@ -28,6 +28,8 @@ class ConstantFolding : public GraphTransformer { const InlinedHashSet& compatible_execution_providers = {}, const InlinedHashSet& excluded_initializers = {}) noexcept; + virtual bool AllowConstantFolding(const Node&) const { return true; } + private: Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; diff --git a/onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.cc b/onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.cc new file mode 100644 index 0000000000000..da31176f7a48c --- /dev/null +++ b/onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.cc @@ -0,0 +1,359 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include + +#include "core/optimizer/qdq_transformer/constant_folding_dq_node.h" +#include "core/optimizer/initializer.h" +#include "core/optimizer/utils.h" +#include "core/graph/graph_utils.h" +#include "core/optimizer/optimizer_execution_frame.h" +#include "core/optimizer/utils.h" +#include "core/framework/op_kernel.h" +#include "core/framework/tensorprotoutils.h" + +using namespace onnxruntime::common; + +namespace onnxruntime { +/* Constructs the transformer: the first five arguments are forwarded unchanged to the ConstantFolding base constructor, and node_index_in_compute_capability, skip_dequantize_linear, config_options, excluded_initializers and execution_provider are additionally kept in this class's own members of the same name. The body is intentionally empty. */ +ConstantFoldingDQ::ConstantFoldingDQ(const IExecutionProvider& execution_provider, + bool skip_dequantize_linear, + const ConfigOptions& config_options, + const InlinedHashSet& compatible_execution_providers, + const InlinedHashSet& excluded_initializers, + const InlinedHashSet& node_index_in_compute_capability) noexcept + : ConstantFolding(execution_provider,skip_dequantize_linear, config_options, compatible_execution_providers, excluded_initializers), + node_index_in_compute_capability_(node_index_in_compute_capability), + skip_dequantize_linear_(skip_dequantize_linear), + config_options_(config_options), + excluded_initializers_(excluded_initializers), + execution_provider_(execution_provider) { +} + +// We need to handle a Shape node separately as the input doesn't need to be a constant initializer for +// Shape to be able to be constant folded. 
+static bool ConstantFoldShapeNode(Graph& graph, Node& node) { + // Opset-15 Shape supports slicing using a 'start' and 'end' attribute + const auto& shape_attributes = node.GetAttributes(); + + int64_t start = 0; + int64_t end = std::numeric_limits::max(); + + for (const auto& attr : shape_attributes) { + if (attr.first == "start") { + start = attr.second.i(); + } else if (attr.first == "end") { + end = attr.second.i(); + } + } + + auto shape = node.MutableInputDefs()[0]->Shape(); + bool is_concrete_shape = true; + std::vector dim_values; + if (shape != nullptr) { + for (int dim_index = 0; dim_index < shape->dim_size(); dim_index++) { + auto dim = shape->dim(dim_index); + if (!utils::HasDimValue(dim)) { + is_concrete_shape = false; + break; + } + dim_values.push_back(dim.dim_value()); + } + } else { + is_concrete_shape = false; + } + + if (is_concrete_shape) { + int64_t rank = static_cast(dim_values.size()); + + // We ascertain the "true" starts/ends (if they were provided) + // Opset-15 Shape op supports slicing shape values + + // Deal with negatives and clamp + start = start < 0 ? start + rank : start; + start = start < 0 ? 0 : ((start > rank) ? rank : start); + + end = end < 0 ? end + rank : end; + end = end < 0 ? 0 : ((end > rank) ? rank : end); + + int64_t slice_length = end - start; + size_t clamped_slice_length = slice_length < 0 ? 
0 : static_cast(slice_length); + + ONNX_NAMESPACE::TensorProto shape_constant; + auto* constant_arg_out = node.MutableOutputDefs()[0]; + shape_constant.set_name(constant_arg_out->Name()); + shape_constant.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + shape_constant.add_dims(clamped_slice_length); + utils::SetRawDataInTensorProto(shape_constant, dim_values.data() + start, clamped_slice_length * sizeof(int64_t)); + ONNX_NAMESPACE::TensorShapeProto result_shape; + result_shape.add_dim()->set_dim_value(clamped_slice_length); + constant_arg_out->SetShape(result_shape); + graph.AddInitializedTensor(shape_constant); + } + + return is_concrete_shape; // convert to constant if this is true +} + +// This function inlines the appropriate subgraph. It does not literally fold it. +static Status ConstantFoldIfNode(Graph& graph, Node& if_node, const logging::Logger& logger, bool& folded) { + folded = false; + // First, find out which subgraph to inline + // We need to fetch the constant argument. + assert(if_node.InputDefs().size() == 1); + const auto* condition_def = if_node.InputDefs()[0]; + + // We need to check if the condition is a constant. + constexpr bool check_outer_scope_true = true; + const ONNX_NAMESPACE::TensorProto* initializer = + graph.GetConstantInitializer(condition_def->Name(), check_outer_scope_true); + if (initializer == nullptr) { + return Status::OK(); + } + + // This is a boolean initializer with a single element. + Initializer condition{*initializer}; + ORT_RETURN_IF_NOT(condition.size() == 1, "If node condition initializer: `", condition_def->Name(), + "' is expected to have a single boolean element"); + + const bool condition_value = *condition.data(); + + auto status = graph.InlineIfSubgraph(condition_value, if_node, logger); + + if (!status.IsOK()) { + LOGS(logger, WARNING) << "Unable to constant fold. 
InlineIfSubgraph failed " + << " node '" << if_node.Name() << "': " + << status.ErrorMessage(); + return status; + } + + graph_utils::RemoveNodeOutputEdges(graph, if_node); + graph.RemoveNode(if_node.Index()); + + folded = true; + return status; +} + +bool ConstantFoldingDQ::AllowConstantFolding(const Node& node) const { + if (node_index_in_compute_capability_.find(node.Index()) != node_index_in_compute_capability_.end()) { + return true; + } + return false; +} + +Status ConstantFoldingDQ::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { + bool have_updated_nodes = false; + GraphViewer graph_viewer(graph); + auto& order = graph_viewer.GetNodesInTopologicalOrder(); + +#if !defined(DISABLE_SPARSE_TENSORS) + std::function is_sparse_initializer_check = [&graph](const std::string& name) -> bool { + return graph.IsSparseInitializer(name); + }; +#endif + + for (NodeIndex i : order) { + auto* node = graph.GetNode(i); + if (!node) { + continue; + } + + if (!AllowConstantFolding(*node)) { + continue; + } + + ORT_RETURN_IF_ERROR(Recurse(*node, modified, graph_level, logger)); + + // Updating a node may allow shape inferencing to infer output shapes of following nodes, + // so re-run the shape inferencing. use have_updated_nodes as that only applies to this Graph + // (vs. 'modified' which is passed into subgraphs and applies to the main graph and all subgraphs) + // Ignore any control flow node containing subgraphs as UpdateShapeInference is not intended to be used on it. + if (have_updated_nodes && !node->ContainsSubgraph()) { + ORT_RETURN_IF_ERROR(graph.UpdateShapeInference(*node)); + } + + bool converted_to_constant = false; + if (node->OpType().compare("If") == 0) { + // This process constant folds the If node only, + // but inlines the nodes of the corresponding branch graph. + // It does not convert the node to a constant in a common sense. 
+ // We call it constant folding because the `If` node constant condition + // may enable us to inline the corresponding branch graph. + bool folded = false; + ORT_RETURN_IF_ERROR(ConstantFoldIfNode(graph, *node, logger, folded)); + if (folded) { + // Node removal is done within ConstantFoldIfNode() + modified = true; + have_updated_nodes = true; + } + } else if (node->OpType().compare("Shape") == 0) { + converted_to_constant = ConstantFoldShapeNode(graph, *node); + } else { + InitializedTensorSet constant_inputs; + + // Check if constant folding can be applied on this node. + const auto can_constant_fold_node = [&](const Node& n, bool skip_inputs_constant_check = false) { + return graph_utils::IsSupportedProvider(n, GetCompatibleExecutionProviders()) && + optimizer_utils::IsOperationDeterministic(n.Domain(), n.OpType()) && + // constant folding does not support executing a node that includes subgraphs (control flow operators, + // such as If/Loop/Scan, fall into this category). individual nodes in the subgraph will be processed + // by the Recurse call above + !n.ContainsSubgraph() && + (skip_inputs_constant_check || + graph_utils::AllNodeInputsAreConstant(graph, n, constant_inputs, excluded_initializers_)); + }; + + if (!can_constant_fold_node(*node)) { + continue; + } + + // if skip_dequantize_linear is true we want to maintain QDQ node units so avoid constant folding + // DequantizeLinear unless we can fold the whole QDQ node unit + if (skip_dequantize_linear_ && node->OpType() == "DequantizeLinear") { + bool can_constant_fold_qdq_node_unit = false; + + // Simplest scenario where the whole QDQ node unit of (DQ -> X -> Q) can be constant folded is if: + // - the DQ node does not produce a graph output, and its output is only consumed by X + // - X is a deterministic node with a single input and single output + // - the output from X is not a graph output and is only consumed by a Q node + if (optimizer_utils::CheckOutputEdges(graph, *node, 1)) { // DQ does not 
produce graph output, single consumer + const Node& node_x = *node->OutputNodesBegin(); + if (node_x.InputDefs().size() == 1 && + node_x.OutputDefs().size() == 1 && + optimizer_utils::CheckOutputEdges(graph, node_x, 1)) { + const Node& probably_q = *node_x.OutputNodesBegin(); + + if (probably_q.OpType() == "QuantizeLinear") { + // the inputs to these nodes are not const yet, but will be if we constant fold, + // so set skip_const_check to simulate that having happened + constexpr bool skip_const_check = true; + can_constant_fold_qdq_node_unit = can_constant_fold_node(node_x, skip_const_check) && + can_constant_fold_node(probably_q, skip_const_check); + } + } + } + + if (!can_constant_fold_qdq_node_unit) { + continue; + } + } + +#if !defined(DISABLE_SPARSE_TENSORS) + // Create execution frame for executing constant nodes. + OptimizerExecutionFrame::Info info({node}, constant_inputs, graph.ModelPath(), execution_provider_, + is_sparse_initializer_check, logger); +#else + // Create execution frame for executing constant nodes. + OptimizerExecutionFrame::Info info( + {node}, constant_inputs, graph.ModelPath(), execution_provider_, [](const std::string&) { return false; }, + logger); +#endif + + std::vector fetch_mlvalue_idxs; + for (const auto* node_out : node->OutputDefs()) { + fetch_mlvalue_idxs.push_back(info.GetMLValueIndex(node_out->Name())); + } + + const bool node_on_cpu_ep = node->GetExecutionProviderType() == kCpuExecutionProvider; + + std::unique_ptr kernel; + + if (!node_on_cpu_ep) { + // We need to copy the string here instead of taking a reference to it since node->SetExecutionProviderType + // will change the value of the reference + auto ep_type = node->GetExecutionProviderType(); + + // override the EP assigned to the node so that it will use the CPU kernel for Compute. 
+ node->SetExecutionProviderType(kCpuExecutionProvider); + + kernel = info.CreateKernel(node, config_options_); + + // undo the EP change to the value that was assigned at graph partitioning time + node->SetExecutionProviderType(ep_type); + } else { + kernel = info.CreateKernel(node, config_options_); + } + + // We currently constant fold using the CPU EP only. + // If we can't find a CPU kernel for this node, then we can't proceed with constant folding. + // + // TODO(adrianlizarraga): Support constant folding with other execution providers. For example, we may be able + // to use a CUDA kernel to constant fold operators with data types not supported by the CPU EP kernel. + if (kernel == nullptr) { + LOGS(logger, WARNING) << "Could not find a CPU kernel and hence " + << "can't constant fold " << node->OpType() << " node '" << node->Name() << "'"; + + // Move on to the next candidate node + continue; + } + + OptimizerExecutionFrame frame(info, fetch_mlvalue_idxs); +#ifdef _WIN32 +#pragma warning(push) +#pragma warning(disable : 6387) +#endif + OpKernelContext op_kernel_context(&frame, kernel.get(), /*stream*/ nullptr, nullptr, logger); + ORT_RETURN_IF_ERROR(kernel->Compute(&op_kernel_context)); +#ifdef _WIN32 +#pragma warning(pop) +#endif + + std::vector fetches; + ORT_RETURN_IF_ERROR(frame.GetOutputs(fetches)); + + // Go over all output node args and substitute them with the newly computed tensors, which will be + // added to the graph as initializers. + ORT_ENFORCE(fetches.size() == node->OutputDefs().size()); + converted_to_constant = true; + for (size_t fetch_idx = 0; fetch_idx < fetches.size(); ++fetch_idx) { + const auto& constant_arg_out = *node->OutputDefs()[fetch_idx]; + // XXX: Add support for SparseTensors outputs when we have sparse outputs + if (!utils::HasTensorType(*constant_arg_out.TypeAsProto())) { + LOGS(logger, INFO) << "Unsupported output type of " << constant_arg_out.Type() + << ". 
Can't constant fold " << node->OpType() << " node '" << node->Name() << "'"; + converted_to_constant = false; + break; + } + } + + if (converted_to_constant) { + for (size_t fetch_idx = 0; fetch_idx < fetches.size(); ++fetch_idx) { + OrtValue& ort_value = fetches[fetch_idx]; + // Build the TensorProto that corresponds to the computed OrtValue and add it as initializer to the graph. + auto* constant_arg_out = node->MutableOutputDefs()[fetch_idx]; + const Tensor& out_tensor = ort_value.Get(); + ONNX_NAMESPACE::TensorProto out_tensorproto = utils::TensorToTensorProto(out_tensor, constant_arg_out->Name()); + + ONNX_NAMESPACE::TensorShapeProto result_shape; + for (auto& dim : out_tensor.Shape().GetDims()) { + result_shape.add_dim()->set_dim_value(dim); + } + + constant_arg_out->SetShape(result_shape); + graph.AddInitializedTensor(out_tensorproto); + } + } + } + + if (converted_to_constant) { + // Remove single-output node chain for inputs of the node + auto p_ip_node = node->InputNodesBegin(); + const auto p_ip_node_end = node->InputNodesEnd(); + while (p_ip_node != p_ip_node_end) { + const auto& input_node = *p_ip_node; + // Update the node iterator before removing the corresponding node because removing + // the node will invalidate the node iterator + ++p_ip_node; + graph_utils::RemoveNodesWithOneOutputBottomUp(graph, input_node); + } + + // Remove the output edges of the constant node and then remove the node itself. + graph_utils::RemoveNodeOutputEdges(graph, *node); + graph.RemoveNode(node->Index()); + modified = true; + have_updated_nodes = true; + } + } + + return Status::OK(); +} +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.h b/onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.h new file mode 100644 index 0000000000000..3a9d8006ae099 --- /dev/null +++ b/onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.h @@ -0,0 +1,45 @@ +// Copyright (c) Microsoft Corporation. 
All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/optimizer/graph_transformer.h" +#include "core/optimizer/constant_folding.h" +#include "core/framework/ort_value.h" +#include +#include "core/framework/execution_provider.h" + +namespace onnxruntime { + +/** +@class ConstantFoldingDQ + +Transformer that traverses the graph top-down and performs constant folding, i.e., +it statically computes parts of the graph that rely only on constant initializers. Only nodes accepted by AllowConstantFolding (those whose index is in node_index_in_compute_capability) are processed. +*/ +class ConstantFoldingDQ : public ConstantFolding { + public: + /*! Constant folding will not be applied to nodes that have one of initializers from excluded_initializers as input. + For pre-training, the trainable weights are those initializers to be excluded. + \param execution_provider Execution provider instance to execute constant folding. + */ + ConstantFoldingDQ(const IExecutionProvider& execution_provider, + bool skip_dequantize_linear, + const ConfigOptions& config_options, + const InlinedHashSet& compatible_execution_providers = {}, + const InlinedHashSet& excluded_initializers = {}, + const InlinedHashSet& node_index_in_compute_capability = {}) noexcept; + + bool AllowConstantFolding(const Node&) const override; /* true only for nodes whose index is in node_index_in_compute_capability_ */ + + private: + Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; + + const InlinedHashSet node_index_in_compute_capability_; /* owned copy, not a reference: the constructor argument defaults to a temporary that is destroyed right after construction */ + bool skip_dequantize_linear_; + const ConfigOptions& config_options_; + const InlinedHashSet excluded_initializers_; + const IExecutionProvider& execution_provider_; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index 5a179ec622f8c..9e8116177c3a7 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -1221,6 +1221,8 @@ struct ProviderHost { virtual Status 
LoadDynamicLibrary(onnxruntime::PathString library_name) = 0; #endif + virtual Status GetEPOptimizerByName(const std::string& optimizer_name, std::function>(const GraphViewer&)>& selection_func) = 0; + // ModelMetadefIdGenerator virtual std::unique_ptr ModelMetadefIdGenerator__construct() = 0; virtual void ModelMetadefIdGenerator__operator_delete(ModelMetadefIdGenerator* p) = 0; diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index af39edae2074d..682f361d96df5 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -203,6 +203,29 @@ common::Status LoadDynamicLibraryFromProvider(onnxruntime::PathString library_na } #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) +Status ApplyConstantFoldingOnDQ(const Graph&, const ComputeCapability& this_optimization, ComputeCapability& cc_to_update) { + /* Placeholder: performs no optimization yet; both ComputeCapability arguments are unused and the call always reports success. */ + return Status::OK(); +} +/* Selection function: walks the nodes in priority-based topological order, collects the index of every DequantizeLinear node into one sub-graph, and returns a single-element vector whose entry wraps that sub-graph and has ApplyConstantFoldingOnDQ as its optimization_func. The entry is returned even when no DequantizeLinear node was found. NOTE(review): the std::cout statement prints every selected node to stdout and looks like leftover debug output - confirm before merging. */ +std::vector> dq_nodes_to_constant_fold(const GraphViewer& graph_viewer) { + std::vector> result; + std::unique_ptr sub_graph = std::make_unique(); + const std::vector& node_index = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED /*priority-based topological sort*/); + for (const auto& index : node_index) { + const auto& node = graph_viewer.GetNode(index); + if (node->OpType() != "DequantizeLinear") { + continue; + } + sub_graph->nodes.push_back(index); + std::cout << node->Name() << ", op type: " << node->OpType() << std::endl; + } + + result.push_back(std::make_unique(std::move(sub_graph))); + result.back()->optimization_func = ApplyConstantFoldingOnDQ; + return result; +} + #if defined(_MSC_VER) && !defined(__clang__) #pragma warning(push) #pragma warning(disable : 26436) @@ -1485,7 +1508,14 @@ struct ProviderHostImpl : ProviderHost { #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) Status LoadDynamicLibrary(onnxruntime::PathString library_name) 
override { return LoadDynamicLibraryFromProvider(library_name); }; #endif + + Status GetEPOptimizerByName(const std::string& optimizer_name, std::function>(const GraphViewer&)>& selection_func) override { + selection_func = dq_nodes_to_constant_fold; + return Status::OK(); + }; + } provider_host_; + #if defined(_MSC_VER) && !defined(__clang__) #pragma warning(pop) #endif From e9119d513cb0445cc7613074ebeaa41b1588e719 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Sat, 25 Jan 2025 21:28:49 -0800 Subject: [PATCH 02/24] include GraphTransformerManager to GetCapability --- .../core/framework/execution_provider.h | 10 ++++++ .../core/framework/execution_provider.cc | 3 +- .../core/framework/graph_partitioner.cc | 35 ++++++++++++++----- .../core/framework/graph_partitioner.h | 5 ++- .../providers/cuda/cuda_execution_provider.cc | 3 +- .../providers/cuda/cuda_execution_provider.h | 3 +- .../providers/shared_library/provider_api.h | 1 + .../provider_bridge_provider.cc | 5 +-- .../shared_library/provider_interfaces.h | 3 +- .../tensorrt/tensorrt_execution_provider.cc | 3 +- .../tensorrt/tensorrt_execution_provider.h | 3 +- onnxruntime/core/session/inference_session.cc | 7 ++-- .../core/session/provider_bridge_ort.cc | 9 +++-- .../test/framework/inference_session_test.cc | 3 +- .../test/framework/session_state_test.cc | 12 +++---- .../internal_testing_execution_provider.cc | 3 +- .../internal_testing_execution_provider.h | 3 +- 17 files changed, 79 insertions(+), 32 deletions(-) diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h index 0d9e6db1a7748..c1113e9fe7af2 100644 --- a/include/onnxruntime/core/framework/execution_provider.h +++ b/include/onnxruntime/core/framework/execution_provider.h @@ -20,6 +20,7 @@ struct ComputeCapability; class KernelRegistry; struct KernelCreateInfo; class Node; +class GraphTransformerManager; } // namespace onnxruntime #else #include @@ -128,9 +129,18 @@ class 
IExecutionProvider { For kernels registered in a kernel registry, `kernel_lookup` must be used to find a matching kernel for this EP. */ + /* virtual std::vector> GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& kernel_lookup) const; + */ + + virtual std::vector> + GetCapability(const onnxruntime::GraphViewer& graph_viewer, + const IKernelLookup& kernel_lookup, + const onnxruntime::GraphTransformerManager& graph_transformer_mgr) const; + + virtual bool RequestCustomizedGraphOptimizationForEP() const { return false; } /** Get kernel registry per execution provider type. diff --git a/onnxruntime/core/framework/execution_provider.cc b/onnxruntime/core/framework/execution_provider.cc index b39924d4c3ff9..cae3a37e4146a 100644 --- a/onnxruntime/core/framework/execution_provider.cc +++ b/onnxruntime/core/framework/execution_provider.cc @@ -13,7 +13,8 @@ namespace onnxruntime { std::vector> IExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, - const IKernelLookup& kernel_lookup) const { + const IKernelLookup& kernel_lookup, + const onnxruntime::GraphTransformerManager& graph_transformer_mgr) const { std::vector> result; for (const auto& node : graph.Nodes()) { if (const KernelCreateInfo* kernel_create_info = kernel_lookup.LookUpKernel(node); diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index b97cf03e3bf59..a20225095e5b7 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -58,6 +58,7 @@ struct PartitionParams { std::reference_wrapper fused_node_unique_id; std::reference_wrapper transform_layout_function; std::reference_wrapper debug_graph_fn; + std::reference_wrapper graph_transformer_manager; #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) }; } // namespace @@ -130,13 +131,20 @@ struct GetCapabilityForEPParams { GraphPartitioner::Mode mode; std::reference_wrapper 
transform_layout; std::reference_wrapper debug_graph_fn; + std::reference_wrapper graph_transformer_manager; #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) }; auto get_capabilities = [](const IExecutionProvider& ep, const GraphViewer& graph_viewer, - const IExecutionProvider::IKernelLookup& kernel_lookup) { - auto capabilities = ep.GetCapability(graph_viewer, kernel_lookup); + const IExecutionProvider::IKernelLookup& kernel_lookup, + const onnxruntime::GraphTransformerManager& graph_transformer_manager) { + std::vector> capabilities; + if (ep.RequestCustomizedGraphOptimizationForEP()) { + capabilities = ep.GetCapability(graph_viewer, kernel_lookup, graph_transformer_manager); + } else { + //capabilities = ep.GetCapability(graph_viewer, kernel_lookup); + } // In theory an EP could return an empty capability. Remove those. capabilities.erase(std::remove_if(capabilities.begin(), capabilities.end(), @@ -170,10 +178,11 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params, const l auto& graph = params.graph.get(); auto& capabilities = params.capabilities.get(); + auto& graph_transformer_manager = params.graph_transformer_manager.get(); { const GraphViewer graph_viewer(graph); - capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup); + capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup, graph_transformer_manager); if (capabilities.empty()) { return Status::OK(); @@ -211,7 +220,7 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params, const l capabilities.clear(); const GraphViewer graph_viewer(graph); - capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup); + capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup, graph_transformer_manager); // all nodes with an index >= first_new_node with domain of kMSInternalNHWCDomain should be in the capabilities InlinedHashSet new_nodes_in_capabilities; @@ -248,6 +257,7 @@ static Status 
GetCapabilityForEP(const GetCapabilityForEPParams& params, const l // It also does not perform layout transformation. This will be done during normal partitioning. static Status GetCapabilityForEPForAotInlining(const GraphViewer& graph_viewer, const KernelRegistryManager& kernel_registry_mgr, + const GraphTransformerManager& graph_transformer_mgr, const IExecutionProvider& current_ep, const logging::Logger& logger, std::vector>& capabilities) { @@ -260,7 +270,7 @@ static Status GetCapabilityForEPForAotInlining(const GraphViewer& graph_viewer, logger}; // TODO: Provide EP with a capability to look inside the functions. - capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup); + capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup, graph_transformer_mgr); return Status::OK(); } @@ -363,6 +373,7 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr, int& fused_node_unique_id, const layout_transformation::TransformLayoutFunction& transform_layout_fn, const layout_transformation::DebugGraphFn& debug_graph_fn, + const onnxruntime::GraphTransformerManager& graph_transfomer_manager, const logging::Logger& logger) { // handle testing edge case where optimizers or constant lifting results in graph with no nodes. 
// doing it here saves all providers checking for this in GetCapability @@ -377,7 +388,7 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr, // we pass through the FuncManager from the top level graph ORT_RETURN_IF_ERROR(PartitionOnnxFormatModelImpl(*subgraph, func_mgr, kernel_registry_mgr, fused_kernel_registry, current_ep, mode, fused_node_unique_id, - transform_layout_fn, debug_graph_fn, logger)); + transform_layout_fn, debug_graph_fn, graph_transfomer_manager, logger)); } } @@ -400,7 +411,8 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr, std::ref(capabilities), mode, std::cref(transform_layout_fn), - std::cref(debug_graph_fn)}; + std::cref(debug_graph_fn), + std::cref(graph_transfomer_manager)}; ORT_RETURN_IF_ERROR(GetCapabilityForEP(get_capability_params, logger)); if (capabilities.empty()) { @@ -562,6 +574,7 @@ static Status InlineNodes(Graph& graph, bool& modified_graph) { static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_providers, const KernelRegistryManager& kernel_registry_mgr, + const GraphTransformerManager& graph_transformer_mgr, Graph& graph, const logging::Logger& logger, InlinedHashSet& not_inlined, @@ -578,6 +591,7 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide // we pass through the FuncManager from the top level graph ORT_RETURN_IF_ERROR(InlineFunctionsAOTImpl(execution_providers, kernel_registry_mgr, + graph_transformer_mgr, *subgraph, logger, not_inlined, @@ -603,7 +617,7 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide InlinedHashSet claimed_by_ep; for (const auto& ep : execution_providers) { std::vector> capabilities; - ORT_RETURN_IF_ERROR(GetCapabilityForEPForAotInlining(graph_viewer, kernel_registry_mgr, *ep, logger, + ORT_RETURN_IF_ERROR(GetCapabilityForEPForAotInlining(graph_viewer, kernel_registry_mgr, graph_transformer_mgr, * ep, logger, capabilities)); for (auto& capability : 
capabilities) { const auto& nodes = capability->sub_graph->nodes; @@ -743,6 +757,7 @@ static Status PartitionOnnxFormatModel(const PartitionParams& partition_params, auto& fused_kernel_registry = partition_params.fused_kernel_registry.get(); auto& fused_node_unique_id = partition_params.fused_node_unique_id.get(); const auto& transform_layout_function = partition_params.transform_layout_function; + const auto& graph_transformer_manager = partition_params.graph_transformer_manager; do { // process full graph with each EP @@ -751,6 +766,7 @@ static Status PartitionOnnxFormatModel(const PartitionParams& partition_params, fused_kernel_registry, *ep, mode, fused_node_unique_id, transform_layout_function, partition_params.debug_graph_fn, + graph_transformer_manager, logger)); } @@ -802,6 +818,7 @@ static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_param GraphPartitioner::Mode::kOrtFormatLoad, std::cref(partition_params.transform_layout_function), std::cref(partition_params.debug_graph_fn), + std::cref(partition_params.graph_transformer_manager), #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) }; // clang-format on @@ -917,6 +934,7 @@ Status GraphPartitioner::InlineFunctionsAOT(Model& model, size_t inlined_count = 0; ORT_RETURN_IF_ERROR(InlineFunctionsAOTImpl(execution_providers, kernel_registry_manager, + graph_transformer_mgr_, graph, logger, not_inlined, @@ -975,6 +993,7 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, std::ref(fused_node_unique_id), std::cref(transform_layout_function), std::cref(debug_graph_fn), + std::cref(graph_transformer_mgr_), }; #else // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) diff --git a/onnxruntime/core/framework/graph_partitioner.h b/onnxruntime/core/framework/graph_partitioner.h index d1ef193cf1520..8fcc1de782a6a 100644 --- a/onnxruntime/core/framework/graph_partitioner.h +++ b/onnxruntime/core/framework/graph_partitioner.h @@ -7,6 
+7,7 @@ #include "core/graph/graph.h" #include "core/framework/fuse_nodes_funcs.h" #include "core/framework/transform_layout_functions.h" +#include "core/optimizer/graph_transformer_mgr.h" namespace onnxruntime { @@ -24,8 +25,9 @@ class GraphPartitioner { }; // The order of providers represents the user preference. - GraphPartitioner(KernelRegistryManager& kernel_registry_mgr, const ExecutionProviders& providers) + GraphPartitioner(KernelRegistryManager& kernel_registry_mgr, const GraphTransformerManager& graph_transformer_mgr, const ExecutionProviders& providers) : kernel_registry_mgr_(kernel_registry_mgr), + graph_transformer_mgr_(graph_transformer_mgr), providers_(providers) { } @@ -64,6 +66,7 @@ class GraphPartitioner { KernelRegistryManager& kernel_registry_mgr_; const ExecutionProviders& providers_; + const GraphTransformerManager& graph_transformer_mgr_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 4a10de153653c..5b923bcd31b49 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -2658,7 +2658,8 @@ std::unique_ptr CUDAExecutionProvider::GetDataTransf std::vector> CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, - const IKernelLookup& kernel_lookup) const { + const IKernelLookup& kernel_lookup, + const GraphTransformerManager& graph_transformer_mgr) const { InlinedVector candidates; // A subset of the above vector. A subset of the tentative_nodes might be moved to CPU. 
InlinedVector tentative_nodes; diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.h b/onnxruntime/core/providers/cuda/cuda_execution_provider.h index bd2be2eac2181..b9f06e136ad17 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.h +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.h @@ -72,7 +72,8 @@ class CUDAExecutionProvider : public IExecutionProvider { std::vector> GetCapability( const onnxruntime::GraphViewer& graph, - const IKernelLookup& kernel_lookup) const override; + const IKernelLookup& kernel_lookup, + const GraphTransformerManager& graph_transformer_mgr) const override; int GetDeviceId() const override { return info_.device_id; } const cudaDeviceProp& GetDeviceProp() const { return device_prop_; }; diff --git a/onnxruntime/core/providers/shared_library/provider_api.h b/onnxruntime/core/providers/shared_library/provider_api.h index 45f81ed22b7f7..3d1f78e14fcdc 100644 --- a/onnxruntime/core/providers/shared_library/provider_api.h +++ b/onnxruntime/core/providers/shared_library/provider_api.h @@ -175,6 +175,7 @@ struct SparseTensor; class TensorSeq; class SessionState; class ModelMetadefIdGenerator; +class GraphTransformerManager; class If; class Loop; diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc index aa8c367d25d51..380be5bef7904 100644 --- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc +++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc @@ -331,8 +331,9 @@ bool IAllocator::CalcMemSizeForArrayWithAlignment(size_t nmemb, size_t size, siz } std::vector> IExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, - const IKernelLookup& kernel_lookup) const { - return g_host->IExecutionProvider__GetCapability(this, graph_viewer, kernel_lookup); + const IKernelLookup& kernel_lookup, + const onnxruntime::GraphTransformerManager& 
graph_transformer_mgr) const { + return g_host->IExecutionProvider__GetCapability(this, graph_viewer, kernel_lookup, graph_transformer_mgr); } common::Status IExecutionProvider::Compile(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) { diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index 9e8116177c3a7..71a9d78d6363f 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -246,7 +246,8 @@ struct ProviderHost { // IExecutionProvider virtual std::vector> IExecutionProvider__GetCapability(const IExecutionProvider* p, const onnxruntime::GraphViewer& graph_viewer, - const IExecutionProvider::IKernelLookup& kernel_lookup) = 0; + const IExecutionProvider::IKernelLookup& kernel_lookup, + const onnxruntime::GraphTransformerManager& graph_transformer_mgr) = 0; virtual common::Status IExecutionProvider__Compile(IExecutionProvider* p, const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) = 0; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index c583598bbcc52..709167764fdcc 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -2451,7 +2451,8 @@ bool TensorrtExecutionProvider::DetectTensorRTGraphCycles(SubGraphCollection_t& std::vector> TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, - const IKernelLookup& /*kernel_lookup*/) const { + const IKernelLookup&, /*kernel_lookup*/ + const GraphTransformerManager& graph_transformer_mgr) const { // Construct subgraph capability from node list std::vector> result; // Get ModelPath diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h 
b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index d3e0b0fba8891..e92e898e96786 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -247,7 +247,8 @@ class TensorrtExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const GraphViewer& graph, - const IKernelLookup& /*kernel_lookup*/) const override; + const IKernelLookup&, /*kernel_lookup*/ + const GraphTransformerManager& graph_transformer_mgr) const override; int GetDeviceId() const { return device_id_; } diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 223eed248800e..1fe26f09bf98a 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1207,7 +1207,7 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool // 7. insert copy nodes (required transformer). 
// Run Ahead Of time function inlining - GraphPartitioner partitioner(kernel_registry_manager_, execution_providers_); + GraphPartitioner partitioner(kernel_registry_manager_, graph_transformer_mgr_, execution_providers_); if (const bool disable_aot_function_inlining = session_options_.config_options.GetConfigOrDefault( kOrtSessionOptionsDisableAheadOfTimeFunctionInlining, "0") == "1"; @@ -1598,6 +1598,7 @@ namespace { Status PartitionOrtFormatModel(onnxruntime::Graph& graph, const ExecutionProviders& providers, KernelRegistryManager& kernel_registry_manager, + const onnxruntime::GraphTransformerManager& graph_transformer_manager, SessionState& session_state, const ConfigOptions& config_options, const logging::Logger& logger) { @@ -1617,7 +1618,7 @@ Status PartitionOrtFormatModel(onnxruntime::Graph& graph, } #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) - GraphPartitioner partitioner(kernel_registry_manager, providers); + GraphPartitioner partitioner(kernel_registry_manager, graph_transformer_manager, providers); ORT_RETURN_IF_ERROR(partitioner.Partition(graph, session_state.GetMutableFuncMgr(), transform_layout_fn, @@ -2052,7 +2053,7 @@ common::Status InferenceSession::Initialize() { "Loading anything other than ORT format models is not enabled in this build.")); #endif // !defined(ORT_MINIMAL_BUILD) } else { - ORT_RETURN_IF_ERROR_SESSIONID_(PartitionOrtFormatModel(graph, execution_providers_, kernel_registry_manager_, + ORT_RETURN_IF_ERROR_SESSIONID_(PartitionOrtFormatModel(graph, execution_providers_, kernel_registry_manager_, graph_transformer_mgr_, *session_state_, session_options_.config_options, *session_logger_)); #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 682f361d96df5..cda88ef9891c1 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ 
b/onnxruntime/core/session/provider_bridge_ort.cc @@ -203,8 +203,10 @@ common::Status LoadDynamicLibraryFromProvider(onnxruntime::PathString library_na } #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) +//onnxruntime::GraphTransformerManager graph_transformer_mgr_(10 /*max_num_graph_transformation_steps*/); + Status ApplyConstantFoldingOnDQ(const Graph&, const ComputeCapability& this_optimization, ComputeCapability& cc_to_update) { - + //auto logger = const_cast(&logging::LoggingManager::DefaultLogger()); return Status::OK(); } @@ -359,8 +361,9 @@ struct ProviderHostImpl : ProviderHost { // IExecutionProvider (direct) std::vector> IExecutionProvider__GetCapability( const IExecutionProvider* p, const onnxruntime::GraphViewer& graph_viewer, - const IExecutionProvider::IKernelLookup& kernel_lookup) override { - return p->IExecutionProvider::GetCapability(graph_viewer, kernel_lookup); + const IExecutionProvider::IKernelLookup& kernel_lookup, + const onnxruntime::GraphTransformerManager& graph_transformer_mgr) override { + return p->IExecutionProvider::GetCapability(graph_viewer, kernel_lookup, graph_transformer_mgr); } common::Status IExecutionProvider__Compile(IExecutionProvider* p, const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) override { diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc index 740c566794f15..a94520077774f 100644 --- a/onnxruntime/test/framework/inference_session_test.cc +++ b/onnxruntime/test/framework/inference_session_test.cc @@ -137,7 +137,8 @@ class FuseExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const onnxruntime::GraphViewer& graph, - const IKernelLookup& /*kernel_lookup*/) const override { + const IKernelLookup& /*kernel_lookup*/, + const onnxruntime::GraphTransformerManager& graph_transformer_mgr) const override { // Fuse two add into one. 
std::vector> result; std::unique_ptr sub_graph = std::make_unique(); diff --git a/onnxruntime/test/framework/session_state_test.cc b/onnxruntime/test/framework/session_state_test.cc index e7f8b1aaa49d8..a9760e62f217a 100644 --- a/onnxruntime/test/framework/session_state_test.cc +++ b/onnxruntime/test/framework/session_state_test.cc @@ -261,8 +261,8 @@ TEST_P(SessionStateTestP, TestInitializerProcessing) { SessionState session_state(graph, execution_providers, tp.get(), nullptr, dtm, edlm, DefaultLoggingManager().DefaultLogger(), profiler, sess_options); - - GraphPartitioner partitioner(krm, execution_providers); + GraphTransformerManager graph_transformer_mgr(10); + GraphPartitioner partitioner(krm, graph_transformer_mgr, execution_providers); ASSERT_STATUS_OK( partitioner.Partition( graph, session_state.GetMutableFuncMgr(), @@ -347,9 +347,9 @@ TEST(SessionStateTest, TestInitializerMemoryAllocatedUsingNonArenaMemory) { SessionState session_state(graph, execution_providers, nullptr, nullptr, dtm, edlm, DefaultLoggingManager().DefaultLogger(), profiler, sess_options); - + GraphTransformerManager graph_transformer_mgr(10); // Partition the graph - GraphPartitioner partitioner(krm, execution_providers); + GraphPartitioner partitioner(krm, graph_transformer_mgr, execution_providers); ASSERT_STATUS_OK(partitioner.Partition( graph, session_state.GetMutableFuncMgr(), [&cpu_allocator](Graph& graph, bool& modified, const IExecutionProvider& execution_provider, @@ -406,9 +406,9 @@ TEST(SessionStateTest, TestInitializerMemoryAllocatedUsingNonArenaMemory) { SessionState session_state(graph, execution_providers, nullptr, nullptr, dtm, edlm, DefaultLoggingManager().DefaultLogger(), profiler, sess_options); - + GraphTransformerManager graph_transformer_mgr(10); // Partition the graph - GraphPartitioner partitioner(krm, execution_providers); + GraphPartitioner partitioner(krm, graph_transformer_mgr, execution_providers); ASSERT_STATUS_OK(partitioner.Partition( graph, 
session_state.GetMutableFuncMgr(), [&cpu_allocator](Graph& graph, bool& modified, diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc index 2e073def5d643..31049d2ab7e3c 100644 --- a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc +++ b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc @@ -110,7 +110,8 @@ DataLayout InternalTestingExecutionProvider::GetPreferredLayout() const { std::vector> InternalTestingExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, - const IKernelLookup& kernel_lookup) const { + const IKernelLookup& kernel_lookup, + const GraphTransformerManager& graph_transformer_mgr) const { // find nodes that have ops in our supported list std::unordered_set supported_static_nodes; std::unordered_set supported_compiled_nodes; diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h index 6615eb82f2b05..3a193224d6309 100644 --- a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h +++ b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h @@ -19,7 +19,8 @@ class InternalTestingExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const onnxruntime::GraphViewer& graph_view, - const IKernelLookup& /*kernel_lookup*/) const override; + const IKernelLookup& /*kernel_lookup*/, + const onnxruntime::GraphTransformerManager& graph_transformer_mgr) const override; common::Status Compile(const std::vector& fused_nodes, std::vector& node_compute_funcs) override; From b7a0b7938e5a9583262f180fe38be9170aa1e501 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Sun, 26 Jan 2025 15:36:22 -0800 Subject: [PATCH 03/24] Add GraphTransformerManager for EP, 
optimization function and ComputeCapability, selection function and ComputeCapability --- .../core/optimizer/graph_transformer_utils.h | 6 + .../core/framework/compute_capability.h | 2 +- .../core/framework/graph_partitioner.cc | 3 + .../core/optimizer/constant_folding.cc | 15 +- onnxruntime/core/optimizer/constant_folding.h | 11 +- .../core/optimizer/graph_transformer_mgr.cc | 7 + .../core/optimizer/graph_transformer_mgr.h | 3 + .../core/optimizer/graph_transformer_utils.cc | 25 ++ .../constant_folding_dq_node.cc | 328 +----------------- .../constant_folding_dq_node.h | 8 +- .../shared_library/provider_interfaces.h | 4 +- .../tensorrt/tensorrt_execution_provider.cc | 17 +- .../tensorrt/tensorrt_execution_provider.h | 2 + .../tensorrt_execution_provider_helper.cc | 13 + onnxruntime/core/session/inference_session.cc | 34 +- onnxruntime/core/session/inference_session.h | 6 + .../core/session/provider_bridge_ort.cc | 88 ++++- 17 files changed, 219 insertions(+), 353 deletions(-) diff --git a/include/onnxruntime/core/optimizer/graph_transformer_utils.h b/include/onnxruntime/core/optimizer/graph_transformer_utils.h index 31b0f22340510..e4a5b011dfeb6 100644 --- a/include/onnxruntime/core/optimizer/graph_transformer_utils.h +++ b/include/onnxruntime/core/optimizer/graph_transformer_utils.h @@ -58,6 +58,12 @@ InlinedVector> GenerateTransformers( concurrency::ThreadPool* intra_op_thread_pool = nullptr, std::unordered_map>* p_buffered_tensors = nullptr); +InlinedVector> GenerateTransformersForEP( + TransformerLevel level, + const SessionOptions& session_options, + const IExecutionProvider& cpu_execution_provider, /*required by constant folding*/ + const logging::Logger& logger); + #endif // !defined(ORT_MINIMAL_BUILD) #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) diff --git a/onnxruntime/core/framework/compute_capability.h b/onnxruntime/core/framework/compute_capability.h index de479ed029304..34f315878d4a8 100644 --- 
a/onnxruntime/core/framework/compute_capability.h +++ b/onnxruntime/core/framework/compute_capability.h @@ -28,7 +28,7 @@ struct ComputeCapability { // optional function to optimize this ComputeCapability // this will be called by ORT once the ComputeCapability is assigned to the EP // Optimization: std::function - std::function optimization_func; + std::function optimization_func; // optional ComputeCapability instances for sets of nodes within this ComputeCapability that should be optimized. // when an optimization is applied, ORT will update this ComputeCapability to reflect the changes made. diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index a20225095e5b7..ba58cfe927963 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -140,11 +140,14 @@ auto get_capabilities = [](const IExecutionProvider& ep, const IExecutionProvider::IKernelLookup& kernel_lookup, const onnxruntime::GraphTransformerManager& graph_transformer_manager) { std::vector> capabilities; + capabilities = ep.GetCapability(graph_viewer, kernel_lookup, graph_transformer_manager); + /* if (ep.RequestCustomizedGraphOptimizationForEP()) { capabilities = ep.GetCapability(graph_viewer, kernel_lookup, graph_transformer_manager); } else { //capabilities = ep.GetCapability(graph_viewer, kernel_lookup); } + */ // In theory an EP could return an empty capability. Remove those. 
capabilities.erase(std::remove_if(capabilities.begin(), capabilities.end(), diff --git a/onnxruntime/core/optimizer/constant_folding.cc b/onnxruntime/core/optimizer/constant_folding.cc index e755b4bfa6364..cb122c854d41e 100644 --- a/onnxruntime/core/optimizer/constant_folding.cc +++ b/onnxruntime/core/optimizer/constant_folding.cc @@ -28,6 +28,19 @@ ConstantFolding::ConstantFolding(const IExecutionProvider& execution_provider, execution_provider_(execution_provider) { } +ConstantFolding::ConstantFolding(const std::string& name, + const IExecutionProvider& execution_provider, + bool skip_dequantize_linear, + const ConfigOptions& config_options, + const InlinedHashSet& compatible_execution_providers, + const InlinedHashSet& excluded_initializers) noexcept + : GraphTransformer(name, compatible_execution_providers), + skip_dequantize_linear_(skip_dequantize_linear), + config_options_(config_options), + excluded_initializers_(excluded_initializers), + execution_provider_(execution_provider) { +} + // We need to handle a Shape node separately as the input doesn't need to be a constant initializer for // Shape to be able to be constant folded. 
static bool ConstantFoldShapeNode(Graph& graph, Node& node) { @@ -144,7 +157,7 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level, for (NodeIndex i : order) { auto* node = graph.GetNode(i); - if (!node) { + if (!node || !AllowConstantFolding(*node)) { continue; } diff --git a/onnxruntime/core/optimizer/constant_folding.h b/onnxruntime/core/optimizer/constant_folding.h index 4fe2658e89cda..7ce227f86fb04 100644 --- a/onnxruntime/core/optimizer/constant_folding.h +++ b/onnxruntime/core/optimizer/constant_folding.h @@ -28,7 +28,16 @@ class ConstantFolding : public GraphTransformer { const InlinedHashSet& compatible_execution_providers = {}, const InlinedHashSet& excluded_initializers = {}) noexcept; - virtual bool AllowConstantFolding(const Node&) const { return true; } + /* Same as above but with a name provided by derived class. + */ + ConstantFolding(const std::string& name, + const IExecutionProvider& execution_provider, + bool skip_dequantize_linear, + const ConfigOptions& config_options, + const InlinedHashSet& compatible_execution_providers = {}, + const InlinedHashSet& excluded_initializers = {}) noexcept; + + virtual bool AllowConstantFolding(const Node& node) const { return true; } private: Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; diff --git a/onnxruntime/core/optimizer/graph_transformer_mgr.cc b/onnxruntime/core/optimizer/graph_transformer_mgr.cc index 039283bb2d4e1..ac7d3a74534d1 100644 --- a/onnxruntime/core/optimizer/graph_transformer_mgr.cc +++ b/onnxruntime/core/optimizer/graph_transformer_mgr.cc @@ -55,4 +55,11 @@ common::Status GraphTransformerManager::Register(std::unique_ptr GenerateRuleBasedGraphTransformer( return rule_transformer; } +InlinedVector> GenerateTransformersForEP( + TransformerLevel level, + const SessionOptions& session_options, + const IExecutionProvider& cpu_execution_provider, /*required by constant folding*/ + const 
logging::Logger& logger) { + InlinedVector> transformers; + switch (level) { + case TransformerLevel::Level1: { + break; + } + case TransformerLevel::Level2: { + transformers.emplace_back(std::make_unique(cpu_execution_provider, false /*skip_dequantize_linear*/, + session_options.config_options)); + break; + } + case TransformerLevel::Level3: { + break; + } + default: + ORT_THROW("Unsupported optimization level: ", static_cast(level)); + } + return transformers; +} + InlinedVector> GenerateTransformers( TransformerLevel level, const SessionOptions& session_options, diff --git a/onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.cc b/onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.cc index da31176f7a48c..395c431027c82 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.cc @@ -22,116 +22,8 @@ ConstantFoldingDQ::ConstantFoldingDQ(const IExecutionProvider& execution_provide const InlinedHashSet& compatible_execution_providers, const InlinedHashSet& excluded_initializers, const InlinedHashSet& node_index_in_compute_capability) noexcept - : ConstantFolding(execution_provider,skip_dequantize_linear, config_options, compatible_execution_providers, excluded_initializers), - node_index_in_compute_capability_(node_index_in_compute_capability), - skip_dequantize_linear_(skip_dequantize_linear), - config_options_(config_options), - excluded_initializers_(excluded_initializers), - execution_provider_(execution_provider) { -} - -// We need to handle a Shape node separately as the input doesn't need to be a constant initializer for -// Shape to be able to be constant folded. 
-static bool ConstantFoldShapeNode(Graph& graph, Node& node) { - // Opset-15 Shape supports slicing using a 'start' and 'end' attribute - const auto& shape_attributes = node.GetAttributes(); - - int64_t start = 0; - int64_t end = std::numeric_limits::max(); - - for (const auto& attr : shape_attributes) { - if (attr.first == "start") { - start = attr.second.i(); - } else if (attr.first == "end") { - end = attr.second.i(); - } - } - - auto shape = node.MutableInputDefs()[0]->Shape(); - bool is_concrete_shape = true; - std::vector dim_values; - if (shape != nullptr) { - for (int dim_index = 0; dim_index < shape->dim_size(); dim_index++) { - auto dim = shape->dim(dim_index); - if (!utils::HasDimValue(dim)) { - is_concrete_shape = false; - break; - } - dim_values.push_back(dim.dim_value()); - } - } else { - is_concrete_shape = false; - } - - if (is_concrete_shape) { - int64_t rank = static_cast(dim_values.size()); - - // We ascertain the "true" starts/ends (if they were provided) - // Opset-15 Shape op supports slicing shape values - - // Deal with negatives and clamp - start = start < 0 ? start + rank : start; - start = start < 0 ? 0 : ((start > rank) ? rank : start); - - end = end < 0 ? end + rank : end; - end = end < 0 ? 0 : ((end > rank) ? rank : end); - - int64_t slice_length = end - start; - size_t clamped_slice_length = slice_length < 0 ? 
0 : static_cast(slice_length); - - ONNX_NAMESPACE::TensorProto shape_constant; - auto* constant_arg_out = node.MutableOutputDefs()[0]; - shape_constant.set_name(constant_arg_out->Name()); - shape_constant.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); - shape_constant.add_dims(clamped_slice_length); - utils::SetRawDataInTensorProto(shape_constant, dim_values.data() + start, clamped_slice_length * sizeof(int64_t)); - ONNX_NAMESPACE::TensorShapeProto result_shape; - result_shape.add_dim()->set_dim_value(clamped_slice_length); - constant_arg_out->SetShape(result_shape); - graph.AddInitializedTensor(shape_constant); - } - - return is_concrete_shape; // convert to constant if this is true -} - -// This function inlines the appropriate subgraph. It does not literally fold it. -static Status ConstantFoldIfNode(Graph& graph, Node& if_node, const logging::Logger& logger, bool& folded) { - folded = false; - // First, find out which subgraph to inline - // We need to fetch the constant argument. - assert(if_node.InputDefs().size() == 1); - const auto* condition_def = if_node.InputDefs()[0]; - - // We need to check if the condition is a constant. - constexpr bool check_outer_scope_true = true; - const ONNX_NAMESPACE::TensorProto* initializer = - graph.GetConstantInitializer(condition_def->Name(), check_outer_scope_true); - if (initializer == nullptr) { - return Status::OK(); - } - - // This is a boolean initializer with a single element. - Initializer condition{*initializer}; - ORT_RETURN_IF_NOT(condition.size() == 1, "If node condition initializer: `", condition_def->Name(), - "' is expected to have a single boolean element"); - - const bool condition_value = *condition.data(); - - auto status = graph.InlineIfSubgraph(condition_value, if_node, logger); - - if (!status.IsOK()) { - LOGS(logger, WARNING) << "Unable to constant fold. 
InlineIfSubgraph failed " - << " node '" << if_node.Name() << "': " - << status.ErrorMessage(); - return status; - } - - graph_utils::RemoveNodeOutputEdges(graph, if_node); - graph.RemoveNode(if_node.Index()); - - folded = true; - return status; -} + : ConstantFolding("ConstantFoldingDQ", execution_provider, skip_dequantize_linear, config_options, compatible_execution_providers, excluded_initializers), + node_index_in_compute_capability_(node_index_in_compute_capability) {} bool ConstantFoldingDQ::AllowConstantFolding(const Node& node) const { if (node_index_in_compute_capability_.find(node.Index()) != node_index_in_compute_capability_.end()) { @@ -140,220 +32,4 @@ bool ConstantFoldingDQ::AllowConstantFolding(const Node& node) const { return false; } -Status ConstantFoldingDQ::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { - bool have_updated_nodes = false; - GraphViewer graph_viewer(graph); - auto& order = graph_viewer.GetNodesInTopologicalOrder(); - -#if !defined(DISABLE_SPARSE_TENSORS) - std::function is_sparse_initializer_check = [&graph](const std::string& name) -> bool { - return graph.IsSparseInitializer(name); - }; -#endif - - for (NodeIndex i : order) { - auto* node = graph.GetNode(i); - if (!node) { - continue; - } - - if (!AllowConstantFolding(*node)) { - continue; - } - - ORT_RETURN_IF_ERROR(Recurse(*node, modified, graph_level, logger)); - - // Updating a node may allow shape inferencing to infer output shapes of following nodes, - // so re-run the shape inferencing. use have_updated_nodes as that only applies to this Graph - // (vs. 'modified' which is passed into subgraphs and applies to the main graph and all subgraphs) - // Ignore any control flow node containing subgraphs as UpdateShapeInference is not intended to be used on it. 
- if (have_updated_nodes && !node->ContainsSubgraph()) { - ORT_RETURN_IF_ERROR(graph.UpdateShapeInference(*node)); - } - - bool converted_to_constant = false; - if (node->OpType().compare("If") == 0) { - // This process constant folds the If node only, - // but inlines the nodes of the corresponding branch graph. - // It does not convert the node to a constant in a common sense. - // We call it constant folding because the `If` node constant condition - // may enable us to inline the corresponding branch graph. - bool folded = false; - ORT_RETURN_IF_ERROR(ConstantFoldIfNode(graph, *node, logger, folded)); - if (folded) { - // Node removal is done within ConstantFoldIfNode() - modified = true; - have_updated_nodes = true; - } - } else if (node->OpType().compare("Shape") == 0) { - converted_to_constant = ConstantFoldShapeNode(graph, *node); - } else { - InitializedTensorSet constant_inputs; - - // Check if constant folding can be applied on this node. - const auto can_constant_fold_node = [&](const Node& n, bool skip_inputs_constant_check = false) { - return graph_utils::IsSupportedProvider(n, GetCompatibleExecutionProviders()) && - optimizer_utils::IsOperationDeterministic(n.Domain(), n.OpType()) && - // constant folding does not support executing a node that includes subgraphs (control flow operators, - // such as If/Loop/Scan, fall into this category). 
individual nodes in the subgraph will be processed - // by the Recurse call above - !n.ContainsSubgraph() && - (skip_inputs_constant_check || - graph_utils::AllNodeInputsAreConstant(graph, n, constant_inputs, excluded_initializers_)); - }; - - if (!can_constant_fold_node(*node)) { - continue; - } - - // if skip_dequantize_linear is true we want to maintain QDQ node units so avoid constant folding - // DequantizeLinear unless we can fold the whole QDQ node unit - if (skip_dequantize_linear_ && node->OpType() == "DequantizeLinear") { - bool can_constant_fold_qdq_node_unit = false; - - // Simplest scenario where the whole QDQ node unit of (DQ -> X -> Q) can be constant folded is if: - // - the DQ node does not produce a graph output, and its output is only consumed by X - // - X is a deterministic node with a single input and single output - // - the output from X is not a graph output and is only consumed by a Q node - if (optimizer_utils::CheckOutputEdges(graph, *node, 1)) { // DQ does not produce graph output, single consumer - const Node& node_x = *node->OutputNodesBegin(); - if (node_x.InputDefs().size() == 1 && - node_x.OutputDefs().size() == 1 && - optimizer_utils::CheckOutputEdges(graph, node_x, 1)) { - const Node& probably_q = *node_x.OutputNodesBegin(); - - if (probably_q.OpType() == "QuantizeLinear") { - // the inputs to these nodes are not const yet, but will be if we constant fold, - // so set skip_const_check to simulate that having happened - constexpr bool skip_const_check = true; - can_constant_fold_qdq_node_unit = can_constant_fold_node(node_x, skip_const_check) && - can_constant_fold_node(probably_q, skip_const_check); - } - } - } - - if (!can_constant_fold_qdq_node_unit) { - continue; - } - } - -#if !defined(DISABLE_SPARSE_TENSORS) - // Create execution frame for executing constant nodes. 
- OptimizerExecutionFrame::Info info({node}, constant_inputs, graph.ModelPath(), execution_provider_, - is_sparse_initializer_check, logger); -#else - // Create execution frame for executing constant nodes. - OptimizerExecutionFrame::Info info( - {node}, constant_inputs, graph.ModelPath(), execution_provider_, [](const std::string&) { return false; }, - logger); -#endif - - std::vector fetch_mlvalue_idxs; - for (const auto* node_out : node->OutputDefs()) { - fetch_mlvalue_idxs.push_back(info.GetMLValueIndex(node_out->Name())); - } - - const bool node_on_cpu_ep = node->GetExecutionProviderType() == kCpuExecutionProvider; - - std::unique_ptr kernel; - - if (!node_on_cpu_ep) { - // We need to copy the string here instead of taking a reference to it since node->SetExecutionProviderType - // will change the value of the reference - auto ep_type = node->GetExecutionProviderType(); - - // override the EP assigned to the node so that it will use the CPU kernel for Compute. - node->SetExecutionProviderType(kCpuExecutionProvider); - - kernel = info.CreateKernel(node, config_options_); - - // undo the EP change to the value that was assigned at graph partitioning time - node->SetExecutionProviderType(ep_type); - } else { - kernel = info.CreateKernel(node, config_options_); - } - - // We currently constant fold using the CPU EP only. - // If we can't find a CPU kernel for this node, then we can't proceed with constant folding. - // - // TODO(adrianlizarraga): Support constant folding with other execution providers. For example, we may be able - // to use a CUDA kernel to constant fold operators with data types not supported by the CPU EP kernel. 
- if (kernel == nullptr) { - LOGS(logger, WARNING) << "Could not find a CPU kernel and hence " - << "can't constant fold " << node->OpType() << " node '" << node->Name() << "'"; - - // Move on to the next candidate node - continue; - } - - OptimizerExecutionFrame frame(info, fetch_mlvalue_idxs); -#ifdef _WIN32 -#pragma warning(push) -#pragma warning(disable : 6387) -#endif - OpKernelContext op_kernel_context(&frame, kernel.get(), /*stream*/ nullptr, nullptr, logger); - ORT_RETURN_IF_ERROR(kernel->Compute(&op_kernel_context)); -#ifdef _WIN32 -#pragma warning(pop) -#endif - - std::vector fetches; - ORT_RETURN_IF_ERROR(frame.GetOutputs(fetches)); - - // Go over all output node args and substitute them with the newly computed tensors, which will be - // added to the graph as initializers. - ORT_ENFORCE(fetches.size() == node->OutputDefs().size()); - converted_to_constant = true; - for (size_t fetch_idx = 0; fetch_idx < fetches.size(); ++fetch_idx) { - const auto& constant_arg_out = *node->OutputDefs()[fetch_idx]; - // XXX: Add support for SparseTensors outputs when we have sparse outputs - if (!utils::HasTensorType(*constant_arg_out.TypeAsProto())) { - LOGS(logger, INFO) << "Unsupported output type of " << constant_arg_out.Type() - << ". Can't constant fold " << node->OpType() << " node '" << node->Name() << "'"; - converted_to_constant = false; - break; - } - } - - if (converted_to_constant) { - for (size_t fetch_idx = 0; fetch_idx < fetches.size(); ++fetch_idx) { - OrtValue& ort_value = fetches[fetch_idx]; - // Build the TensorProto that corresponds to the computed OrtValue and add it as initializer to the graph. 
- auto* constant_arg_out = node->MutableOutputDefs()[fetch_idx]; - const Tensor& out_tensor = ort_value.Get(); - ONNX_NAMESPACE::TensorProto out_tensorproto = utils::TensorToTensorProto(out_tensor, constant_arg_out->Name()); - - ONNX_NAMESPACE::TensorShapeProto result_shape; - for (auto& dim : out_tensor.Shape().GetDims()) { - result_shape.add_dim()->set_dim_value(dim); - } - - constant_arg_out->SetShape(result_shape); - graph.AddInitializedTensor(out_tensorproto); - } - } - } - - if (converted_to_constant) { - // Remove single-output node chain for inputs of the node - auto p_ip_node = node->InputNodesBegin(); - const auto p_ip_node_end = node->InputNodesEnd(); - while (p_ip_node != p_ip_node_end) { - const auto& input_node = *p_ip_node; - // Update the node iterator before removing the corresponding node because removing - // the node will invalidate the node iterator - ++p_ip_node; - graph_utils::RemoveNodesWithOneOutputBottomUp(graph, input_node); - } - - // Remove the output edges of the constant node and then remove the node itself. 
- graph_utils::RemoveNodeOutputEdges(graph, *node); - graph.RemoveNode(node->Index()); - modified = true; - have_updated_nodes = true; - } - } - - return Status::OK(); -} } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.h b/onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.h index 3a9d8006ae099..8dd588d3ca2ce 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.h +++ b/onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.h @@ -30,16 +30,10 @@ class ConstantFoldingDQ : public ConstantFolding { const InlinedHashSet& excluded_initializers = {}, const InlinedHashSet& node_index_in_compute_capability = {}) noexcept; - bool AllowConstantFolding(const Node&) const; + bool AllowConstantFolding(const Node& node) const; private: - Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; - const InlinedHashSet& node_index_in_compute_capability_; - bool skip_dequantize_linear_; - const ConfigOptions& config_options_; - const InlinedHashSet excluded_initializers_; - const IExecutionProvider& execution_provider_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index 71a9d78d6363f..e75420dee794e 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -1222,7 +1222,9 @@ struct ProviderHost { virtual Status LoadDynamicLibrary(onnxruntime::PathString library_name) = 0; #endif - virtual Status GetEPOptimizerByName(const std::string& optimizer_name, std::function>(const GraphViewer&)>& selection_func) = 0; + virtual Status GetEPOptimizerByName(const std::string& optimizer_name, + const GraphTransformerManager& graph_transformer_mgr, + std::function>(const GraphViewer&)>& selection_func) = 0; // 
ModelMetadefIdGenerator virtual std::unique_ptr ModelMetadefIdGenerator__construct() = 0; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 709167764fdcc..60f2db7800786 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -2555,7 +2555,8 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, } bool early_termination = false; - supported_nodes_vector = GetSupportedList(parser_nodes_vector, 0, max_partition_iterations_, graph, &early_termination); + //supported_nodes_vector = GetSupportedList(parser_nodes_vector, 0, max_partition_iterations_, graph, &early_termination); + supported_nodes_vector = parser_nodes_vector; if (early_termination) { supported_nodes_vector.clear(); } @@ -2656,11 +2657,23 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, } } + std::function>(const GraphViewer&)> selection_func; + auto status = g_host->GetEPOptimizerByName("ConstantFoldingDQ", graph_transformer_mgr, selection_func); + auto optimizer_cc = selection_func(graph); + + std::unordered_map consumer_to_dq; + CreateConsumerToDqMap(graph, consumer_to_dq); + + int number_of_trt_nodes = 0, subgraph_index = 0; for (const auto& group : supported_nodes_vector) { if (!group.first.empty()) { std::unique_ptr sub_graph = GetSubGraph(group, graph, model_hash, subgraph_index); - result.push_back(ComputeCapability::Create(std::move(sub_graph))); + auto compute_capability = ComputeCapability::Create(std::move(sub_graph)); + + + result.push_back(std::move(compute_capability)); + //result.push_back(ComputeCapability::Create(std::move(sub_graph))); number_of_trt_nodes += static_cast(group.first.size()); subgraph_index++; } diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h 
b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index e92e898e96786..39dc588ae8459 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -591,5 +591,7 @@ class TensorrtExecutionProvider : public IExecutionProvider { * This function only creates the instance at the first time it's being called." */ nvinfer1::IBuilder* GetBuilder(TensorrtLogger& trt_logger) const; + + void CreateConsumerToDqMap(const GraphViewer& graph, std::unordered_map& map) const; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_helper.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_helper.cc index 92fa101118506..523f6a544f619 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_helper.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_helper.cc @@ -258,4 +258,17 @@ void TensorrtExecutionProvider::SetAllGraphInputs(Graph& graph) const { graph.SetInputs(graph_inputs_including_initializers); } + +void TensorrtExecutionProvider::CreateConsumerToDqMap(const GraphViewer& graph, std::unordered_map& map) const { + LOGS_DEFAULT(VERBOSE) << "Create consumer node to DQ node map ..."; + const std::vector& node_index = graph.GetNodesInTopologicalOrder(1 /*priority-based topological sort*/); + for (auto index : node_index) { + auto* node = graph.GetNode(index); + if (node->OpType() == "DequantizeLinear" && node->GetOutputEdgesCount() == 1) { // DQ does not produce graph output, single consumer + const Node& consumer_node = *node->OutputNodesBegin(); + map[consumer_node.Index()] = index; + LOGS_DEFAULT(VERBOSE) << consumer_node.Name() << " <- " << node->Name(); + } + } +} } // namespace onnxruntime diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 1fe26f09bf98a..f1020baf7114e 100644 --- 
a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -646,6 +646,7 @@ InferenceSession::InferenceSession(const SessionOptions& session_options, const : #if !defined(ORT_MINIMAL_BUILD) graph_transformer_mgr_(session_options.max_num_graph_transformation_steps), + ep_graph_transformer_mgr_(session_options.max_num_graph_transformation_steps), #endif environment_(session_env) { // Initialize assets of this session instance @@ -659,6 +660,7 @@ InferenceSession::InferenceSession(const SessionOptions& session_options, : #if !defined(ORT_MINIMAL_BUILD) graph_transformer_mgr_(session_options.max_num_graph_transformation_steps), + ep_graph_transformer_mgr_(session_options.max_num_graph_transformation_steps), #endif external_intra_op_thread_pool_(external_intra_op_thread_pool), external_inter_op_thread_pool_(external_inter_op_thread_pool), @@ -672,6 +674,7 @@ InferenceSession::InferenceSession(const SessionOptions& session_options, const const PathString& model_uri) : model_location_(model_uri), graph_transformer_mgr_(session_options.max_num_graph_transformation_steps), + ep_graph_transformer_mgr_(session_options.max_num_graph_transformation_steps), environment_(session_env) { auto status = Model::Load(model_location_, model_proto_); ORT_ENFORCE(status.IsOK(), "Given model could not be parsed while creating inference session. 
Error message: ", @@ -692,6 +695,7 @@ InferenceSession::InferenceSession(const SessionOptions& session_options, InferenceSession::InferenceSession(const SessionOptions& session_options, const Environment& session_env, std::istream& model_istream) : graph_transformer_mgr_(session_options.max_num_graph_transformation_steps), + ep_graph_transformer_mgr_(session_options.max_num_graph_transformation_steps), environment_(session_env) { Status st = Model::Load(model_istream, &model_proto_); ORT_ENFORCE(st.IsOK(), "Could not parse model successfully while constructing the inference session"); @@ -703,6 +707,7 @@ InferenceSession::InferenceSession(const SessionOptions& session_options, const InferenceSession::InferenceSession(const SessionOptions& session_options, const Environment& session_env, const void* model_data, int model_data_len) : graph_transformer_mgr_(session_options.max_num_graph_transformation_steps), + ep_graph_transformer_mgr_(session_options.max_num_graph_transformation_steps), environment_(session_env) { const bool result = model_proto_.ParseFromArray(model_data, model_data_len); ORT_ENFORCE(result, "Could not parse model successfully while constructing the inference session"); @@ -1207,7 +1212,7 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool // 7. insert copy nodes (required transformer). 
// Run Ahead Of time function inlining - GraphPartitioner partitioner(kernel_registry_manager_, graph_transformer_mgr_, execution_providers_); + GraphPartitioner partitioner(kernel_registry_manager_, ep_graph_transformer_mgr_, execution_providers_); if (const bool disable_aot_function_inlining = session_options_.config_options.GetConfigOrDefault( kOrtSessionOptionsDisableAheadOfTimeFunctionInlining, "0") == "1"; @@ -1845,6 +1850,11 @@ common::Status InferenceSession::Initialize() { record_runtime_optimization_produced_op_schema, *session_logger_)); + // add predefined transformers for EP + ORT_RETURN_IF_ERROR_SESSIONID_(AddPredefinedTransformersForEP(ep_graph_transformer_mgr_, + session_options_.graph_optimization_level, + *session_logger_)); + #ifdef USE_DML const IExecutionProvider* dmlExecutionProvider = execution_providers_.Get(kDmlExecutionProvider); @@ -3275,6 +3285,28 @@ common::Status InferenceSession::AddPredefinedTransformers( return Status::OK(); } +// Registers all the predefined transformers for EP with transformer manager +common::Status InferenceSession::AddPredefinedTransformersForEP( + GraphTransformerManager& transformer_manager, + TransformerLevel graph_optimization_level, + const logging::Logger& logger) const { + const auto& cpu_ep = *execution_providers_.Get(onnxruntime::kCpuExecutionProvider); + for (int i = static_cast(TransformerLevel::Level1); i <= static_cast(TransformerLevel::MaxLevel); i++) { + TransformerLevel level = static_cast(i); + if (graph_optimization_level >= level) { + // Generate and register transformers for level + auto transformers_to_register = [&]() { + return optimizer_utils::GenerateTransformersForEP(level, session_options_, cpu_ep, logger); + }(); + + for (auto& entry : transformers_to_register) { + ORT_RETURN_IF_ERROR(transformer_manager.Register(std::move(entry), level)); + } + } + } + return Status::OK(); +} + #endif // !defined(ORT_MINIMAL_BUILD) common::Status InferenceSession::WaitForNotification(Notification* 
p_executor_done, int64_t timeout_in_ms) { diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index e28ff75345785..af5f7384a5d79 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -712,9 +712,15 @@ class InferenceSession { RecordRuntimeOptimizationProducedNodeOpSchemaFn record_runtime_optimization_produced_op_schema_fn, const logging::Logger& logger) const; + virtual common::Status AddPredefinedTransformersForEP( + GraphTransformerManager& transformer_manager, + TransformerLevel graph_optimization_level, + const logging::Logger& logger) const; + common::Status TransformGraph(onnxruntime::Graph& graph, bool saving_model_in_ort_format); onnxruntime::GraphTransformerManager graph_transformer_mgr_; + onnxruntime::GraphTransformerManager ep_graph_transformer_mgr_; InlinedHashSet> saved_runtime_optimization_produced_node_op_schemas_; #endif diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index cda88ef9891c1..b8015049de866 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -203,17 +203,16 @@ common::Status LoadDynamicLibraryFromProvider(onnxruntime::PathString library_na } #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) -//onnxruntime::GraphTransformerManager graph_transformer_mgr_(10 /*max_num_graph_transformation_steps*/); - -Status ApplyConstantFoldingOnDQ(const Graph&, const ComputeCapability& this_optimization, ComputeCapability& cc_to_update) { - //auto logger = const_cast(&logging::LoggingManager::DefaultLogger()); +/* +Status ApplyConstantFoldingDQ(const Graph&, const ComputeCapability& this_optimization, ComputeCapability& cc_to_update) { + auto logger = const_cast(&logging::LoggingManager::DefaultLogger()); return Status::OK(); } -std::vector> dq_nodes_to_constant_fold(const GraphViewer& 
graph_viewer) { +std::vector> ConstantFoldingDQ(const GraphViewer& graph_viewer) { std::vector> result; std::unique_ptr sub_graph = std::make_unique(); - const std::vector& node_index = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED /*priority-based topological sort*/); + const std::vector& node_index = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED); for (const auto& index : node_index) { const auto& node = graph_viewer.GetNode(index); if (node->OpType() != "DequantizeLinear") { @@ -224,10 +223,25 @@ std::vector> dq_nodes_to_constant_fold(const } result.push_back(std::make_unique(std::move(sub_graph))); - result.back()->optimization_func = ApplyConstantFoldingOnDQ; + result.back()->optimization_func = ApplyConstantFoldingDQ; return result; } +Status GetPredefinedEPGraphTransformersForLookUp(std::unordered_map>(const GraphViewer&)>>& map) { + static const std::string kEP_GRAPH_TRANSFORMER_CONSTANT_FOLDING_DQ = "ConstantFoldingDQ"; + static std::unordered_map>(const GraphViewer&)>> ep_transformers_map; + + if (ep_transformers_map.find(kEP_GRAPH_TRANSFORMER_CONSTANT_FOLDING_DQ) == ep_transformers_map.end()) { + ep_transformers_map[kEP_GRAPH_TRANSFORMER_CONSTANT_FOLDING_DQ] = ConstantFoldingDQ; + } + + map = ep_transformers_map; + return Status::OK(); +} +*/ + +const GraphTransformerManager* graph_transformer_manager; + #if defined(_MSC_VER) && !defined(__clang__) #pragma warning(push) #pragma warning(disable : 26436) @@ -238,6 +252,60 @@ std::vector> dq_nodes_to_constant_fold(const struct ProviderHostImpl : ProviderHost { const OrtApiBase* OrtGetApiBase() override { return ::OrtGetApiBase(); } + Status GetEPOptimizerByName(const std::string& name, + const GraphTransformerManager& graph_transformer_mgr, + std::function>(const GraphViewer&)>& selection_func) override { + std::string optimizer_name(name); + + // pre-defined graph transformers/optimizers + static const std::string kEP_GRAPH_TRANSFORMER_CONSTANT_FOLDING_DQ = 
"ConstantFoldingDQ"; + + // optimization function of constant folding dq + auto constant_folding_dq_optimization = [&](Graph& graph, const ComputeCapability& this_optimization, ComputeCapability& cc_to_update) -> Status { + auto logger = const_cast(&logging::LoggingManager::DefaultLogger()); + auto transformer = graph_transformer_mgr.GetTransformerByName(optimizer_name); + bool graph_changed = false; + bool modified = false; + + auto status = transformer->Apply(graph, modified, *logger); + graph_changed = graph_changed || modified; + + return Status::OK(); + }; + + // selection function of constant folding dq + auto constant_folding_dq_selection = [&](const GraphViewer& graph_viewer) -> std::vector> { + std::vector> result; + std::unique_ptr sub_graph = std::make_unique(); + const std::vector& node_index = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED /*priority-based topological sort*/); + for (const auto& index : node_index) { + const auto& node = graph_viewer.GetNode(index); + if (node->OpType() != "DequantizeLinear") { + continue; + } + sub_graph->nodes.push_back(index); + std::cout << node->Name() << ", op type: " << node->OpType() << std::endl; + } + + result.push_back(std::make_unique(std::move(sub_graph))); + result.back()->optimization_func = constant_folding_dq_optimization; + return result; + }; + + // optimizer lookup table + static std::unordered_map>(const GraphViewer&)>> ep_transformers_map; + if (ep_transformers_map.find(kEP_GRAPH_TRANSFORMER_CONSTANT_FOLDING_DQ) == ep_transformers_map.end()) { + ep_transformers_map[kEP_GRAPH_TRANSFORMER_CONSTANT_FOLDING_DQ] = constant_folding_dq_selection; + } + + + auto transformer = graph_transformer_mgr.GetTransformerByName(optimizer_name); + if (transformer) { + selection_func = ep_transformers_map[optimizer_name]; + } + return Status::OK(); + }; + void* HeapAllocate(size_t size) override { return new uint8_t[size]; } void HeapFree(void* p) override { delete[] reinterpret_cast(p); } @@ 
-1511,12 +1579,6 @@ struct ProviderHostImpl : ProviderHost { #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) Status LoadDynamicLibrary(onnxruntime::PathString library_name) override { return LoadDynamicLibraryFromProvider(library_name); }; #endif - - Status GetEPOptimizerByName(const std::string& optimizer_name, std::function>(const GraphViewer&)>& selection_func) override { - selection_func = dq_nodes_to_constant_fold; - return Status::OK(); - }; - } provider_host_; #if defined(_MSC_VER) && !defined(__clang__) From 3b28ffc4bce4c91820f492d560ce161960b0615e Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Mon, 27 Jan 2025 17:39:05 -0800 Subject: [PATCH 04/24] refine GraphTransformerManager for EP, optimization function and ComputeCapability, selection function and ComputeCapability --- .../core/framework/graph_partitioner.cc | 25 ++++- .../core/optimizer/graph_transformer_mgr.cc | 21 ++++ .../core/optimizer/graph_transformer_mgr.h | 4 + .../core/optimizer/graph_transformer_utils.cc | 3 +- .../constant_folding_dq_node.cc | 13 ++- .../constant_folding_dq_node.h | 11 ++- .../shared_library/provider_interfaces.h | 9 +- .../shared_library/provider_wrappedtypes.h | 2 + .../tensorrt/tensorrt_execution_provider.cc | 10 +- .../core/session/provider_bridge_ort.cc | 97 ++++++++++++++++--- 10 files changed, 161 insertions(+), 34 deletions(-) diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index ba58cfe927963..bc84e9ab9289a 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -293,11 +293,14 @@ static Node* PlaceNode(Graph& graph, const IndexedSubGraph& capability, IExecutionProvider::FusionStyle fusion_style, const std::string& provider_type, GraphPartitioner::Mode mode, - int& fused_node_unique_id) { + int& fused_node_unique_id, + bool* subgraph_assigned_to_ep) { Node* result = nullptr; + *subgraph_assigned_to_ep = false; if 
(nullptr == capability.GetMetaDef()) { TryAssignSingleNode(graph, capability, provider_type); + *subgraph_assigned_to_ep = true; } else { // The can run a fused in the . @@ -360,12 +363,17 @@ static Node* PlaceNode(Graph& graph, const IndexedSubGraph& capability, } } } + *subgraph_assigned_to_ep = true; } } return result; } +static Status TransformGraph(Graph& graph, const ComputeCapability& this_optimization, ComputeCapability& cc_to_update) { + +} + // for the current EP, recursively iterate through the Graph and any nested subgraphs (recursion is bottom-up). // assign any nodes to the EP that are currently unassigned, and that the EP can handle. static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr, @@ -441,7 +449,20 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr, entry->sub_graph->GetMetaDef() != nullptr; })); for (auto& capability : capabilities) { - Node* n = PlaceNode(graph, *capability->sub_graph, fusion_style, type, mode, fused_node_unique_id); + bool subgraph_assigned_to_ep = false; + Node* n = PlaceNode(graph, *capability->sub_graph, fusion_style, type, mode, fused_node_unique_id, &subgraph_assigned_to_ep); + + // If the subgraph is assigned to the ep and the ComputeCapability has nodes_to_optimize, + // run EP related optimizations and update compute capability (cc). + if (subgraph_assigned_to_ep && !capability->nodes_to_optimize.empty()) { + for (auto& optimization_cc : capability->nodes_to_optimize) { + if (optimization_cc->optimization_func) { + optimization_cc->optimization_func(graph, *optimization_cc, *capability); + // #TODO: Handle nested optimization func? 
+ } + } + } + if (n != nullptr) { // searching in kernel registries, if no kernel registered for the fused_node, use compile approach if (!KernelRegistryManager::HasImplementationOf(kernel_registry_mgr, *n, type, logger)) { diff --git a/onnxruntime/core/optimizer/graph_transformer_mgr.cc b/onnxruntime/core/optimizer/graph_transformer_mgr.cc index ac7d3a74534d1..f56bc3bfab15f 100644 --- a/onnxruntime/core/optimizer/graph_transformer_mgr.cc +++ b/onnxruntime/core/optimizer/graph_transformer_mgr.cc @@ -44,6 +44,27 @@ common::Status GraphTransformerManager::ApplyTransformers(Graph& graph, Transfor return Status::OK(); } +common::Status GraphTransformerManager::ApplyTransformer(Graph& graph, std::string& name, + const logging::Logger& logger) const { + auto transformer = GetTransformerByName(name); + if (!transformer) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "This transformer is not registered " + name); + } + + bool modified = false; + for (unsigned step = 0; step < steps_; ++step) { + if (step > 0 && transformer->ShouldOnlyApplyOnce()) { + break; + } + ORT_RETURN_IF_ERROR(transformer->Apply(graph, modified, logger)); + if (!modified) { + break; + } + } + + return Status::OK(); +} + common::Status GraphTransformerManager::Register(std::unique_ptr transformer, TransformerLevel level) { const auto& name = transformer->Name(); diff --git a/onnxruntime/core/optimizer/graph_transformer_mgr.h b/onnxruntime/core/optimizer/graph_transformer_mgr.h index 8d1bde9cb608c..f2a6b77f5fc39 100644 --- a/onnxruntime/core/optimizer/graph_transformer_mgr.h +++ b/onnxruntime/core/optimizer/graph_transformer_mgr.h @@ -30,6 +30,10 @@ class GraphTransformerManager { // Apply all transformers registered for the given level on the given graph common::Status ApplyTransformers(Graph& graph, TransformerLevel level, const logging::Logger& logger) const; + // Apply one transformer registered by name on the given graph + common::Status GraphTransformerManager::ApplyTransformer(Graph& graph, 
std::string& name, + const logging::Logger& logger) const; + // Get transformer by name. Return nullptr if not found. GraphTransformer* GetTransformerByName(std::string& name) const; diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 2a45a74871210..854d8847ea7c6 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -198,8 +198,9 @@ InlinedVector> GenerateTransformersForEP( break; } case TransformerLevel::Level2: { + const InlinedHashSet node_index_set = {}; transformers.emplace_back(std::make_unique(cpu_execution_provider, false /*skip_dequantize_linear*/, - session_options.config_options)); + session_options.config_options, node_index_set)); break; } case TransformerLevel::Level3: { diff --git a/onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.cc b/onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.cc index 395c431027c82..9bb09b542abbf 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.cc @@ -19,17 +19,22 @@ namespace onnxruntime { ConstantFoldingDQ::ConstantFoldingDQ(const IExecutionProvider& execution_provider, bool skip_dequantize_linear, const ConfigOptions& config_options, + const InlinedHashSet& node_index_set, const InlinedHashSet& compatible_execution_providers, - const InlinedHashSet& excluded_initializers, - const InlinedHashSet& node_index_in_compute_capability) noexcept + const InlinedHashSet& excluded_initializers) noexcept : ConstantFolding("ConstantFoldingDQ", execution_provider, skip_dequantize_linear, config_options, compatible_execution_providers, excluded_initializers), - node_index_in_compute_capability_(node_index_in_compute_capability) {} + node_index_set_(node_index_set) {} bool ConstantFoldingDQ::AllowConstantFolding(const Node& node) const { - if 
(node_index_in_compute_capability_.find(node.Index()) != node_index_in_compute_capability_.end()) { + if (node_index_set_.find(node.Index()) != node_index_set_.end()) { return true; } return false; } +Status ConstantFoldingDQ::UpdateNodeIndexSet(InlinedHashSet& node_index_set) { + node_index_set_ = node_index_set; + return Status::OK(); +} + } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.h b/onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.h index 8dd588d3ca2ce..6b2f8aabdaf3e 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.h +++ b/onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.h @@ -25,15 +25,16 @@ class ConstantFoldingDQ : public ConstantFolding { */ ConstantFoldingDQ(const IExecutionProvider& execution_provider, bool skip_dequantize_linear, - const ConfigOptions& config_options, + const ConfigOptions& config_options, + const InlinedHashSet& node_index_set, const InlinedHashSet& compatible_execution_providers = {}, - const InlinedHashSet& excluded_initializers = {}, - const InlinedHashSet& node_index_in_compute_capability = {}) noexcept; + const InlinedHashSet& excluded_initializers = {}) noexcept; - bool AllowConstantFolding(const Node& node) const; + bool AllowConstantFolding(const Node& node) const; + Status UpdateNodeIndexSet(InlinedHashSet& node_index_set); private: - const InlinedHashSet& node_index_in_compute_capability_; + InlinedHashSet node_index_set_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index e75420dee794e..2b993c47bf52c 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -142,6 +142,10 @@ struct Node__EdgeIterator { struct ProviderHost { virtual const OrtApiBase* OrtGetApiBase() = 0; + 
virtual Status GetEPOptimizerByName(const std::string& name, + const GraphTransformerManager& graph_transformer_mgr, + std::function>(const GraphViewer&)>& selection_func) = 0; + virtual void* HeapAllocate(size_t size) = 0; virtual void HeapFree(void*) = 0; @@ -595,6 +599,7 @@ struct ProviderHost { virtual std::unique_ptr ComputeCapability__construct(std::unique_ptr t_sub_graph) = 0; virtual void ComputeCapability__operator_delete(ComputeCapability* p) = 0; virtual std::unique_ptr& ComputeCapability__SubGraph(ComputeCapability* p) = 0; + virtual void ComputeCapability__add_nodes_to_optimize(ComputeCapability* p, std::unique_ptr optimization_cc) = 0; // DataTransferManager virtual Status DataTransferManager__CopyTensor(const DataTransferManager* p, const Tensor& src, Tensor& dst) = 0; @@ -1222,10 +1227,6 @@ struct ProviderHost { virtual Status LoadDynamicLibrary(onnxruntime::PathString library_name) = 0; #endif - virtual Status GetEPOptimizerByName(const std::string& optimizer_name, - const GraphTransformerManager& graph_transformer_mgr, - std::function>(const GraphViewer&)>& selection_func) = 0; - // ModelMetadefIdGenerator virtual std::unique_ptr ModelMetadefIdGenerator__construct() = 0; virtual void ModelMetadefIdGenerator__operator_delete(ModelMetadefIdGenerator* p) = 0; diff --git a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h index 76b6d8063fd66..2d896427ec74e 100644 --- a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h +++ b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h @@ -499,6 +499,8 @@ struct ComputeCapability final { std::unique_ptr& SubGraph() { return g_host->ComputeCapability__SubGraph(this); } + void add_nodes_to_optimize(std::unique_ptr optimization_cc) { g_host->ComputeCapability__add_nodes_to_optimize(this, std::move(optimization_cc)); } + ComputeCapability() = delete; ComputeCapability(const ComputeCapability&) = delete; 
void operator=(const ComputeCapability&) = delete; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 60f2db7800786..fbe39f413079d 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -2657,20 +2657,26 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, } } + // Enable EP related L2+ graph optimizations: + // 1. Dequantize INT32, UINT16, INT16 constant to FP32 -> Apply constant folding on DQ nodes std::function>(const GraphViewer&)> selection_func; auto status = g_host->GetEPOptimizerByName("ConstantFoldingDQ", graph_transformer_mgr, selection_func); - auto optimizer_cc = selection_func(graph); + auto optimization_cc = selection_func(graph); std::unordered_map consumer_to_dq; CreateConsumerToDqMap(graph, consumer_to_dq); - + // Create compute capability int number_of_trt_nodes = 0, subgraph_index = 0; for (const auto& group : supported_nodes_vector) { if (!group.first.empty()) { std::unique_ptr sub_graph = GetSubGraph(group, graph, model_hash, subgraph_index); auto compute_capability = ComputeCapability::Create(std::move(sub_graph)); + // update compute capability to add node_to_optimize + for (auto& cc : optimization_cc) { + compute_capability->add_nodes_to_optimize(std::move(cc)); + } result.push_back(std::move(compute_capability)); //result.push_back(ComputeCapability::Create(std::move(sub_graph))); diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index b8015049de866..7bc9f5cf53fe1 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -37,6 +37,7 @@ #include "core/framework/model_metadef_id_generator.h" #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" #include 
"core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" +#include "core/optimizer/qdq_transformer/constant_folding_dq_node.h" #include "core/session/onnxruntime_session_options_config_keys.h" #include "core/session/onnxruntime_c_api.h" @@ -253,38 +254,101 @@ struct ProviderHostImpl : ProviderHost { const OrtApiBase* OrtGetApiBase() override { return ::OrtGetApiBase(); } Status GetEPOptimizerByName(const std::string& name, - const GraphTransformerManager& graph_transformer_mgr, + const GraphTransformerManager& transformer_mgr, std::function>(const GraphViewer&)>& selection_func) override { + static const GraphTransformerManager& graph_transformer_mgr = transformer_mgr; std::string optimizer_name(name); // pre-defined graph transformers/optimizers static const std::string kEP_GRAPH_TRANSFORMER_CONSTANT_FOLDING_DQ = "ConstantFoldingDQ"; - // optimization function of constant folding dq + // ConstantFoldingDQ's optimization function auto constant_folding_dq_optimization = [&](Graph& graph, const ComputeCapability& this_optimization, ComputeCapability& cc_to_update) -> Status { + std::string optimizer_name = kEP_GRAPH_TRANSFORMER_CONSTANT_FOLDING_DQ; auto logger = const_cast(&logging::LoggingManager::DefaultLogger()); - auto transformer = graph_transformer_mgr.GetTransformerByName(optimizer_name); - bool graph_changed = false; - bool modified = false; + std::unordered_set original_initializers_to_remove; + std::unordered_set new_initializers_to_add; + InlinedHashSet dq_node_index_set; + + // iterate node_to_optimize to: + // 1. get original initializers to remove + // 2. add new initializers + // 3. 
create dq node index set + for (const auto& index : this_optimization.sub_graph->nodes) { + auto node = graph.GetNode(index); + if (node->OpType() != "DequantizeLinear") { + continue; + } + auto input_0 = node->InputDefs()[0]; + auto output_0 = node->OutputDefs()[0]; + original_initializers_to_remove.insert(input_0->Name()); + new_initializers_to_add.insert(output_0->Name()); + dq_node_index_set.insert(index); + } + + ConstantFoldingDQ* transformer = static_cast(graph_transformer_mgr.GetTransformerByName(optimizer_name)); + transformer->UpdateNodeIndexSet(dq_node_index_set); + + // apply constant folding on DQ nodes + graph_transformer_mgr.ApplyTransformer(graph, optimizer_name, *logger); + + // update the overall ComputeCapability + std::vector updated_nodes; + for (auto index : cc_to_update.sub_graph->nodes) { + if (dq_node_index_set.find(index) != dq_node_index_set.end()) { + continue; + } + updated_nodes.push_back(index); + } + cc_to_update.sub_graph->nodes = updated_nodes; + + auto original_meta_def = cc_to_update.sub_graph->GetMetaDef(); + std::unique_ptr updated_meta_def = std::make_unique(); + updated_meta_def->name = original_meta_def->name; + updated_meta_def->domain = original_meta_def->domain; + updated_meta_def->since_version = original_meta_def->since_version; + updated_meta_def->status = original_meta_def->status; + updated_meta_def->inputs = original_meta_def->inputs; + updated_meta_def->outputs = original_meta_def->outputs; + updated_meta_def->attributes = original_meta_def->attributes; + updated_meta_def->doc_string = original_meta_def->doc_string; +#if !defined(ORT_MINIMAL_BUILD) + updated_meta_def->type_and_shape_inference_function = original_meta_def->type_and_shape_inference_function; +#endif + for (auto constant_initializer : original_meta_def->constant_initializers) { + if (original_initializers_to_remove.find(constant_initializer) != original_initializers_to_remove.end()) { + continue; + } + 
updated_meta_def->constant_initializers.push_back(constant_initializer); + } + + for (auto constant_initializer : new_initializers_to_add) { + updated_meta_def->constant_initializers.push_back(constant_initializer); + } - auto status = transformer->Apply(graph, modified, *logger); - graph_changed = graph_changed || modified; + cc_to_update.sub_graph->SetMetaDef(std::move(updated_meta_def)); return Status::OK(); }; - // selection function of constant folding dq + // ConstantFoldingDQ's selection function auto constant_folding_dq_selection = [&](const GraphViewer& graph_viewer) -> std::vector> { std::vector> result; std::unique_ptr sub_graph = std::make_unique(); const std::vector& node_index = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED /*priority-based topological sort*/); + InitializedTensorSet constant_inputs; + const InlinedHashSet excluded_initializers; + + // Select DequantizeLinear node which dequantizes the bias/constant of Conv, Gemm, LayerNormalization node ... (i.e. 
initializer -> DQ -> bias of X): for (const auto& index : node_index) { const auto& node = graph_viewer.GetNode(index); if (node->OpType() != "DequantizeLinear") { continue; } + if (!graph_utils::AllNodeInputsAreConstant(graph_viewer.GetGraph(), *node, constant_inputs, excluded_initializers)) { + continue; + } sub_graph->nodes.push_back(index); - std::cout << node->Name() << ", op type: " << node->OpType() << std::endl; } result.push_back(std::make_unique(std::move(sub_graph))); @@ -293,15 +357,15 @@ struct ProviderHostImpl : ProviderHost { }; // optimizer lookup table - static std::unordered_map>(const GraphViewer&)>> ep_transformers_map; - if (ep_transformers_map.find(kEP_GRAPH_TRANSFORMER_CONSTANT_FOLDING_DQ) == ep_transformers_map.end()) { - ep_transformers_map[kEP_GRAPH_TRANSFORMER_CONSTANT_FOLDING_DQ] = constant_folding_dq_selection; + static std::unordered_map>(const GraphViewer&)>> optimizer_to_selection_function; + if (optimizer_to_selection_function.find(kEP_GRAPH_TRANSFORMER_CONSTANT_FOLDING_DQ) == optimizer_to_selection_function.end()) { + optimizer_to_selection_function[kEP_GRAPH_TRANSFORMER_CONSTANT_FOLDING_DQ] = constant_folding_dq_selection; } - - auto transformer = graph_transformer_mgr.GetTransformerByName(optimizer_name); - if (transformer) { - selection_func = ep_transformers_map[optimizer_name]; + // auto transformer = graph_transformer_mgr->GetTransformerByName(optimizer_name); + auto look_up = optimizer_to_selection_function.find(optimizer_name); + if (look_up != optimizer_to_selection_function.end()) { + selection_func = optimizer_to_selection_function[optimizer_name]; } return Status::OK(); }; @@ -819,6 +883,7 @@ struct ProviderHostImpl : ProviderHost { std::unique_ptr ComputeCapability__construct(std::unique_ptr t_sub_graph) override { return std::make_unique(std::move(t_sub_graph)); } void ComputeCapability__operator_delete(ComputeCapability* p) override { delete p; } std::unique_ptr& ComputeCapability__SubGraph(ComputeCapability* p) 
override { return p->sub_graph; } + void ComputeCapability__add_nodes_to_optimize(ComputeCapability* p, std::unique_ptr optimization_cc) override { p->nodes_to_optimize.push_back(std::move(optimization_cc)); } // DataTransferManager (wrapped) Status DataTransferManager__CopyTensor(const DataTransferManager* p, const Tensor& src, Tensor& dst) override { return p->CopyTensor(src, dst); } From 309341e86c999a0b5cc630c46b579f29af6a6bca Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Mon, 27 Jan 2025 22:38:34 -0800 Subject: [PATCH 05/24] TRT EP creates optimization compute capability --- .../shared_library/provider_interfaces.h | 1 + .../shared_library/provider_wrappedtypes.h | 1 + .../tensorrt/tensorrt_execution_provider.cc | 44 +++++++++++++++---- .../tensorrt/tensorrt_execution_provider.h | 5 ++- .../tensorrt_execution_provider_helper.cc | 29 ++++++++++-- .../core/session/provider_bridge_ort.cc | 8 ++-- 6 files changed, 71 insertions(+), 17 deletions(-) diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index 2b993c47bf52c..76c650276642e 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -599,6 +599,7 @@ struct ProviderHost { virtual std::unique_ptr ComputeCapability__construct(std::unique_ptr t_sub_graph) = 0; virtual void ComputeCapability__operator_delete(ComputeCapability* p) = 0; virtual std::unique_ptr& ComputeCapability__SubGraph(ComputeCapability* p) = 0; + virtual void ComputeCapability__copy_optimization_func(ComputeCapability* p, ComputeCapability* selection_cc) = 0; virtual void ComputeCapability__add_nodes_to_optimize(ComputeCapability* p, std::unique_ptr optimization_cc) = 0; // DataTransferManager diff --git a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h index 
2d896427ec74e..6146c2e98bd62 100644 --- a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h +++ b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h @@ -499,6 +499,7 @@ struct ComputeCapability final { std::unique_ptr& SubGraph() { return g_host->ComputeCapability__SubGraph(this); } + void copy_optimization_func(ComputeCapability* selection_cc) { g_host->ComputeCapability__copy_optimization_func(this, selection_cc); } void add_nodes_to_optimize(std::unique_ptr optimization_cc) { g_host->ComputeCapability__add_nodes_to_optimize(this, std::move(optimization_cc)); } ComputeCapability() = delete; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index fbe39f413079d..452fa545aef6a 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -2658,13 +2658,17 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, } // Enable EP related L2+ graph optimizations: - // 1. Dequantize INT32, UINT16, INT16 constant to FP32 -> Apply constant folding on DQ nodes + // 1. 
Dequantize INT32, UINT16, INT16 constant to FP32 (Apply constant folding on DQ nodes) std::function>(const GraphViewer&)> selection_func; auto status = g_host->GetEPOptimizerByName("ConstantFoldingDQ", graph_transformer_mgr, selection_func); - auto optimization_cc = selection_func(graph); + std::vector> selection_cc; + if (selection_func) { + selection_cc = selection_func(graph); + } - std::unordered_map consumer_to_dq; - CreateConsumerToDqMap(graph, consumer_to_dq); + std::unordered_set trt_selection_node_set; + std::unordered_map consumer_to_dq; // consumer node -> dq node + CreateConsumerToDqMap(graph, trt_selection_node_set, consumer_to_dq); // Create compute capability int number_of_trt_nodes = 0, subgraph_index = 0; @@ -2673,13 +2677,13 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, std::unique_ptr sub_graph = GetSubGraph(group, graph, model_hash, subgraph_index); auto compute_capability = ComputeCapability::Create(std::move(sub_graph)); - // update compute capability to add node_to_optimize - for (auto& cc : optimization_cc) { - compute_capability->add_nodes_to_optimize(std::move(cc)); + // add optimization compute capability to node_to_optimize + for (auto& cc : selection_cc) { + std::unique_ptr optimization_cc = CreateOptimizationComputeCapability(cc.get(), trt_selection_node_set, compute_capability.get()); + compute_capability->add_nodes_to_optimize(std::move(optimization_cc)); } result.push_back(std::move(compute_capability)); - //result.push_back(ComputeCapability::Create(std::move(sub_graph))); number_of_trt_nodes += static_cast(group.first.size()); subgraph_index++; } @@ -2699,6 +2703,30 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, return result; } +std::unique_ptr TensorrtExecutionProvider::CreateOptimizationComputeCapability(ComputeCapability* selection_cc, + std::unordered_set& trt_selection_node_set, + ComputeCapability* trt_cc) const { + auto sub_graph = onnxruntime::IndexedSubGraph::Create(); + 
std::unordered_set selection_node_set; + + for (auto index : selection_cc->SubGraph()->Nodes()) { + selection_node_set.insert(index); + } + + for (auto index : trt_cc->SubGraph()->Nodes()) { + if (selection_node_set.find(index) == selection_node_set.end()) { + continue; + } + if (trt_selection_node_set.find(index) == trt_selection_node_set.end()) { + continue; + } + sub_graph->Nodes().push_back(index); + } + auto compute_capability = ComputeCapability::Create(std::move(sub_graph)); + compute_capability->copy_optimization_func(selection_cc); + return compute_capability; +} + /** * Refit the weight-stripped engine */ diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index 39dc588ae8459..8fc1827ea6e4c 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -592,6 +592,9 @@ class TensorrtExecutionProvider : public IExecutionProvider { */ nvinfer1::IBuilder* GetBuilder(TensorrtLogger& trt_logger) const; - void CreateConsumerToDqMap(const GraphViewer& graph, std::unordered_map& map) const; + void CreateConsumerToDqMap(const GraphViewer& graph, std::unordered_set& selection_node_set, std::unordered_map& consumer_to_dq) const; + std::unique_ptr TensorrtExecutionProvider::CreateOptimizationComputeCapability(ComputeCapability* selection_cc, + std::unordered_set& trt_selection_node_set, + ComputeCapability* trt_cc) const; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_helper.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_helper.cc index 523f6a544f619..84c2a5e376f29 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_helper.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_helper.cc @@ -259,16 +259,37 @@ void 
TensorrtExecutionProvider::SetAllGraphInputs(Graph& graph) const { graph.SetInputs(graph_inputs_including_initializers); } -void TensorrtExecutionProvider::CreateConsumerToDqMap(const GraphViewer& graph, std::unordered_map& map) const { - LOGS_DEFAULT(VERBOSE) << "Create consumer node to DQ node map ..."; +void TensorrtExecutionProvider::CreateConsumerToDqMap(const GraphViewer& graph, + std::unordered_set& selection_node_set, + std::unordered_map& consumer_to_dq) const { + LOGS_DEFAULT(VERBOSE) << "Select qualified DQ nodes ..."; const std::vector& node_index = graph.GetNodesInTopologicalOrder(1 /*priority-based topological sort*/); for (auto index : node_index) { auto* node = graph.GetNode(index); - if (node->OpType() == "DequantizeLinear" && node->GetOutputEdgesCount() == 1) { // DQ does not produce graph output, single consumer + if (!node) { + continue; + } + + const auto* input_def = node->InputDefs()[0]; // Get NodeArg of the initializer of the DequantizeLinear node; + auto data_type = input_def->TypeAsProto()->tensor_type().elem_type(); + auto constant_initializer = graph.IsConstantInitializer(input_def->Name(), true); + + // Node selection: (i.e. initializer -> DQ -> bias of X) + // 1. DequantizeLinear op + // 2. DQ node does not produce graph output, single consumer + // 3. The fist input of DQ is constant initializer. + // 4. The data type of initializer is INT32, UINT16 or INT16 + // 4. X should be Gemm, Conv or LayerNormalization ? 
+ if (node->OpType() == "DequantizeLinear" && + node->GetOutputEdgesCount() == 1 && + (data_type == ONNX_NAMESPACE::TensorProto_DataType_INT32 || data_type == ONNX_NAMESPACE::TensorProto_DataType_INT16 || data_type == ONNX_NAMESPACE::TensorProto_DataType_UINT16) && + constant_initializer) { const Node& consumer_node = *node->OutputNodesBegin(); - map[consumer_node.Index()] = index; + selection_node_set.insert(index); + consumer_to_dq[consumer_node.Index()] = index; LOGS_DEFAULT(VERBOSE) << consumer_node.Name() << " <- " << node->Name(); } + LOGS_DEFAULT(VERBOSE) << "Total " << selection_node_set.size() << " DequantizeLinear nodes are selected."; } } } // namespace onnxruntime diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 7bc9f5cf53fe1..959c9743dbfee 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -263,7 +263,7 @@ struct ProviderHostImpl : ProviderHost { static const std::string kEP_GRAPH_TRANSFORMER_CONSTANT_FOLDING_DQ = "ConstantFoldingDQ"; // ConstantFoldingDQ's optimization function - auto constant_folding_dq_optimization = [&](Graph& graph, const ComputeCapability& this_optimization, ComputeCapability& cc_to_update) -> Status { + auto constant_folding_dq_optimization = [&](Graph& graph, const ComputeCapability& optimization_cc, ComputeCapability& cc_to_update) -> Status { std::string optimizer_name = kEP_GRAPH_TRANSFORMER_CONSTANT_FOLDING_DQ; auto logger = const_cast(&logging::LoggingManager::DefaultLogger()); std::unordered_set original_initializers_to_remove; @@ -274,7 +274,7 @@ struct ProviderHostImpl : ProviderHost { // 1. get original initializers to remove // 2. add new initializers // 3. 
create dq node index set - for (const auto& index : this_optimization.sub_graph->nodes) { + for (const auto& index : optimization_cc.sub_graph->nodes) { auto node = graph.GetNode(index); if (node->OpType() != "DequantizeLinear") { continue; @@ -339,7 +339,7 @@ struct ProviderHostImpl : ProviderHost { InitializedTensorSet constant_inputs; const InlinedHashSet excluded_initializers; - // Select DequantizeLinear node which dequantizes the bias/constant of Conv, Gemm, LayerNormalization node ... (i.e. initializer -> DQ -> bias of X): + // Select DequantizeLinear node where all inputs are constant for (const auto& index : node_index) { const auto& node = graph_viewer.GetNode(index); if (node->OpType() != "DequantizeLinear") { @@ -362,7 +362,6 @@ struct ProviderHostImpl : ProviderHost { optimizer_to_selection_function[kEP_GRAPH_TRANSFORMER_CONSTANT_FOLDING_DQ] = constant_folding_dq_selection; } - // auto transformer = graph_transformer_mgr->GetTransformerByName(optimizer_name); auto look_up = optimizer_to_selection_function.find(optimizer_name); if (look_up != optimizer_to_selection_function.end()) { selection_func = optimizer_to_selection_function[optimizer_name]; @@ -883,6 +882,7 @@ struct ProviderHostImpl : ProviderHost { std::unique_ptr ComputeCapability__construct(std::unique_ptr t_sub_graph) override { return std::make_unique(std::move(t_sub_graph)); } void ComputeCapability__operator_delete(ComputeCapability* p) override { delete p; } std::unique_ptr& ComputeCapability__SubGraph(ComputeCapability* p) override { return p->sub_graph; } + void ComputeCapability__copy_optimization_func(ComputeCapability* p, ComputeCapability* selection_cc) override { p->optimization_func = selection_cc->optimization_func; } void ComputeCapability__add_nodes_to_optimize(ComputeCapability* p, std::unique_ptr optimization_cc) override { p->nodes_to_optimize.push_back(std::move(optimization_cc)); } // DataTransferManager (wrapped) From d0cbc653822d79fcd8b9ee403c809dc6b46a6129 Mon Sep 17 
00:00:00 2001 From: Chi Lo Date: Tue, 28 Jan 2025 09:05:07 -0800 Subject: [PATCH 06/24] add comments --- .../tensorrt/tensorrt_execution_provider.cc | 46 +++++++------------ .../tensorrt/tensorrt_execution_provider.h | 25 ++++++++-- .../tensorrt_execution_provider_helper.cc | 42 ++++++++++++++++- .../core/session/provider_bridge_ort.cc | 8 ++-- 4 files changed, 82 insertions(+), 39 deletions(-) diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 452fa545aef6a..c9499a10c3479 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -2657,8 +2657,19 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, } } - // Enable EP related L2+ graph optimizations: - // 1. Dequantize INT32, UINT16, INT16 constant to FP32 (Apply constant folding on DQ nodes) + /** + * Enable EP related L2+ graph optimizations with steps: + * + * 1. call provider bridge API to lookup pre-defined optimizer by name and get selection function + * - Run selection function to get selection ComputeCapability + * - ComputeCapability.optimize_func would be set by the optimizer to the function that does the optimization + * + * + * + * Current available optimizations: + * - (ConstantFoldingDQ) constant folding on DQ nodes -> Dequantize INT32, UINT16, INT16 constant to FP32. 
+ */ + std::function>(const GraphViewer&)> selection_func; auto status = g_host->GetEPOptimizerByName("ConstantFoldingDQ", graph_transformer_mgr, selection_func); std::vector> selection_cc; @@ -2666,14 +2677,15 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, selection_cc = selection_func(graph); } - std::unordered_set trt_selection_node_set; + std::unordered_set trt_selection_node_set; // The qualified dq nodes selected by TRT EP std::unordered_map consumer_to_dq; // consumer node -> dq node - CreateConsumerToDqMap(graph, trt_selection_node_set, consumer_to_dq); + SelectQualifiedDQNode(graph, trt_selection_node_set, consumer_to_dq); - // Create compute capability + // Create ComputeCapability int number_of_trt_nodes = 0, subgraph_index = 0; for (const auto& group : supported_nodes_vector) { if (!group.first.empty()) { + // TODO: Use consumer_to_dq table to include DQ node that is filtered out by TRT parser. std::unique_ptr sub_graph = GetSubGraph(group, graph, model_hash, subgraph_index); auto compute_capability = ComputeCapability::Create(std::move(sub_graph)); @@ -2703,30 +2715,6 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, return result; } -std::unique_ptr TensorrtExecutionProvider::CreateOptimizationComputeCapability(ComputeCapability* selection_cc, - std::unordered_set& trt_selection_node_set, - ComputeCapability* trt_cc) const { - auto sub_graph = onnxruntime::IndexedSubGraph::Create(); - std::unordered_set selection_node_set; - - for (auto index : selection_cc->SubGraph()->Nodes()) { - selection_node_set.insert(index); - } - - for (auto index : trt_cc->SubGraph()->Nodes()) { - if (selection_node_set.find(index) == selection_node_set.end()) { - continue; - } - if (trt_selection_node_set.find(index) == trt_selection_node_set.end()) { - continue; - } - sub_graph->Nodes().push_back(index); - } - auto compute_capability = ComputeCapability::Create(std::move(sub_graph)); - 
compute_capability->copy_optimization_func(selection_cc); - return compute_capability; -} - /** * Refit the weight-stripped engine */ diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index 8fc1827ea6e4c..2908c3061f47c 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -592,9 +592,26 @@ class TensorrtExecutionProvider : public IExecutionProvider { */ nvinfer1::IBuilder* GetBuilder(TensorrtLogger& trt_logger) const; - void CreateConsumerToDqMap(const GraphViewer& graph, std::unordered_set& selection_node_set, std::unordered_map& consumer_to_dq) const; - std::unique_ptr TensorrtExecutionProvider::CreateOptimizationComputeCapability(ComputeCapability* selection_cc, - std::unordered_set& trt_selection_node_set, - ComputeCapability* trt_cc) const; + /** + * This is the helper function for ConstantFoldingDQ graph transformer. + * + * It selects the qualified/required DQ node to be optimized as well as provides a mapping table + * to help TRT EP later include the DQ node which is filtered out by TRT parser. + */ + void SelectQualifiedDQNode(const GraphViewer& graph, + std::unordered_set& selection_node_set, + std::unordered_map& consumer_to_dq) const; + + /** + * This function returns an optimization ComputeCapability that is limited to: + * 1. the DQ nodes in this individual TRT ComputeCapability + * 2. the DQ nodes that are qualified and selected by TRT EP + * + * It also needs to make sure the DQ nodes is a subset of the complete list of DQ nodes to optimize in original selection ComputeCapability. + * Finally, copy the optimization function from the original selection ComputeCapability. 
+ */ + std::unique_ptr CreateOptimizationComputeCapability(ComputeCapability* selection_cc, + std::unordered_set& trt_selection_node_set, + ComputeCapability* trt_cc) const; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_helper.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_helper.cc index 84c2a5e376f29..b819711679b48 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_helper.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_helper.cc @@ -259,7 +259,13 @@ void TensorrtExecutionProvider::SetAllGraphInputs(Graph& graph) const { graph.SetInputs(graph_inputs_including_initializers); } -void TensorrtExecutionProvider::CreateConsumerToDqMap(const GraphViewer& graph, +/** + * This is the helper function for ConstantFoldingDQ graph transformer. + * + * It selects the qualified/required DQ node to be optimized as well as provides a mapping table + * to help TRT EP later include the DQ node which is filtered out by TRT parser. + */ +void TensorrtExecutionProvider::SelectQualifiedDQNode(const GraphViewer& graph, std::unordered_set& selection_node_set, std::unordered_map& consumer_to_dq) const { LOGS_DEFAULT(VERBOSE) << "Select qualified DQ nodes ..."; @@ -289,7 +295,39 @@ void TensorrtExecutionProvider::CreateConsumerToDqMap(const GraphViewer& graph, consumer_to_dq[consumer_node.Index()] = index; LOGS_DEFAULT(VERBOSE) << consumer_node.Name() << " <- " << node->Name(); } - LOGS_DEFAULT(VERBOSE) << "Total " << selection_node_set.size() << " DequantizeLinear nodes are selected."; } + LOGS_DEFAULT(VERBOSE) << "Total " << selection_node_set.size() << " DequantizeLinear node(s) are selected."; +} + +/** + * This function returns an optimization ComputeCapability that is limited to: + * 1. the DQ nodes in this individual TRT ComputeCapability + * 2. 
the DQ nodes that are qualified and selected by TRT EP + * + * It also needs to make sure the DQ nodes is a subset of the complete list of DQ nodes to optimize in original selection ComputeCapability. + * Finally, copy the optimization function from the original selection ComputeCapability. + */ +std::unique_ptr TensorrtExecutionProvider::CreateOptimizationComputeCapability(ComputeCapability* selection_cc, + std::unordered_set& trt_selection_node_set, + ComputeCapability* trt_cc) const { + auto sub_graph = onnxruntime::IndexedSubGraph::Create(); + std::unordered_set selection_node_set; + + for (auto index : selection_cc->SubGraph()->Nodes()) { + selection_node_set.insert(index); + } + + for (auto index : trt_cc->SubGraph()->Nodes()) { + if (selection_node_set.find(index) == selection_node_set.end()) { + continue; + } + if (trt_selection_node_set.find(index) == trt_selection_node_set.end()) { + continue; + } + sub_graph->Nodes().push_back(index); + } + auto compute_capability = ComputeCapability::Create(std::move(sub_graph)); + compute_capability->copy_optimization_func(selection_cc); + return compute_capability; } } // namespace onnxruntime diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 959c9743dbfee..fbe299f68f717 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -259,10 +259,10 @@ struct ProviderHostImpl : ProviderHost { static const GraphTransformerManager& graph_transformer_mgr = transformer_mgr; std::string optimizer_name(name); - // pre-defined graph transformers/optimizers + // Pre-defined graph transformers/optimizers static const std::string kEP_GRAPH_TRANSFORMER_CONSTANT_FOLDING_DQ = "ConstantFoldingDQ"; - // ConstantFoldingDQ's optimization function + // ConstantFoldingDQ optimization function auto constant_folding_dq_optimization = [&](Graph& graph, const ComputeCapability& optimization_cc, ComputeCapability& 
cc_to_update) -> Status { std::string optimizer_name = kEP_GRAPH_TRANSFORMER_CONSTANT_FOLDING_DQ; auto logger = const_cast(&logging::LoggingManager::DefaultLogger()); @@ -331,7 +331,7 @@ struct ProviderHostImpl : ProviderHost { return Status::OK(); }; - // ConstantFoldingDQ's selection function + // ConstantFoldingDQ selection function auto constant_folding_dq_selection = [&](const GraphViewer& graph_viewer) -> std::vector> { std::vector> result; std::unique_ptr sub_graph = std::make_unique(); @@ -356,7 +356,7 @@ struct ProviderHostImpl : ProviderHost { return result; }; - // optimizer lookup table + // Optimizer lookup table static std::unordered_map>(const GraphViewer&)>> optimizer_to_selection_function; if (optimizer_to_selection_function.find(kEP_GRAPH_TRANSFORMER_CONSTANT_FOLDING_DQ) == optimizer_to_selection_function.end()) { optimizer_to_selection_function[kEP_GRAPH_TRANSFORMER_CONSTANT_FOLDING_DQ] = constant_folding_dq_selection; From b239db05f87680e737015dd008237816a060dfde Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Tue, 28 Jan 2025 09:31:19 -0800 Subject: [PATCH 07/24] remove unnecessary code --- .../onnxruntime/core/framework/execution_provider.h | 2 -- onnxruntime/core/framework/compute_capability.h | 6 +++--- onnxruntime/core/framework/graph_partitioner.cc | 13 +++---------- 3 files changed, 6 insertions(+), 15 deletions(-) diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h index c1113e9fe7af2..08c4ef2aea531 100644 --- a/include/onnxruntime/core/framework/execution_provider.h +++ b/include/onnxruntime/core/framework/execution_provider.h @@ -140,8 +140,6 @@ class IExecutionProvider { const IKernelLookup& kernel_lookup, const onnxruntime::GraphTransformerManager& graph_transformer_mgr) const; - virtual bool RequestCustomizedGraphOptimizationForEP() const { return false; } - /** Get kernel registry per execution provider type. 
The KernelRegistry share pointer returned is shared across sessions. diff --git a/onnxruntime/core/framework/compute_capability.h b/onnxruntime/core/framework/compute_capability.h index 34f315878d4a8..2d1d4e0e0153f 100644 --- a/onnxruntime/core/framework/compute_capability.h +++ b/onnxruntime/core/framework/compute_capability.h @@ -25,9 +25,9 @@ struct ComputeCapability { ComputeCapability(std::unique_ptr t_sub_graph) : sub_graph(std::move(t_sub_graph)) {} - // optional function to optimize this ComputeCapability - // this will be called by ORT once the ComputeCapability is assigned to the EP - // Optimization: std::function + // Optional function to optimize this ComputeCapability. + // This will be called by ORT once the ComputeCapability is assigned to the EP + // Optimization: std::function std::function optimization_func; // optional ComputeCapability instances for sets of nodes within this ComputeCapability that should be optimized. diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index bc84e9ab9289a..a8f33ff819be2 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -141,13 +141,6 @@ auto get_capabilities = [](const IExecutionProvider& ep, const onnxruntime::GraphTransformerManager& graph_transformer_manager) { std::vector> capabilities; capabilities = ep.GetCapability(graph_viewer, kernel_lookup, graph_transformer_manager); - /* - if (ep.RequestCustomizedGraphOptimizationForEP()) { - capabilities = ep.GetCapability(graph_viewer, kernel_lookup, graph_transformer_manager); - } else { - //capabilities = ep.GetCapability(graph_viewer, kernel_lookup); - } - */ // In theory an EP could return an empty capability. Remove those. 
capabilities.erase(std::remove_if(capabilities.begin(), capabilities.end(), @@ -452,13 +445,13 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr, bool subgraph_assigned_to_ep = false; Node* n = PlaceNode(graph, *capability->sub_graph, fusion_style, type, mode, fused_node_unique_id, &subgraph_assigned_to_ep); - // If the subgraph is assigned to the ep and the ComputeCapability has nodes_to_optimize, - // run EP related optimizations and update compute capability (cc). + // If the subgraph is assigned to the EP and the ComputeCapability has nodes_to_optimize, + // run EP related optimizations and update ComputeCapability. if (subgraph_assigned_to_ep && !capability->nodes_to_optimize.empty()) { for (auto& optimization_cc : capability->nodes_to_optimize) { if (optimization_cc->optimization_func) { optimization_cc->optimization_func(graph, *optimization_cc, *capability); - // #TODO: Handle nested optimization func? + // #TODO: Handle nested optimization ComputeCapability } } } From a83dd111b7fd443989db60fede7b2d6f790a9949 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Tue, 28 Jan 2025 21:23:10 -0800 Subject: [PATCH 08/24] remove commented code --- .../core/framework/graph_partitioner.cc | 4 -- .../core/session/provider_bridge_ort.cc | 37 ------------------- 2 files changed, 41 deletions(-) diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index a8f33ff819be2..41a0ce49061d1 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -363,10 +363,6 @@ static Node* PlaceNode(Graph& graph, const IndexedSubGraph& capability, return result; } -static Status TransformGraph(Graph& graph, const ComputeCapability& this_optimization, ComputeCapability& cc_to_update) { - -} - // for the current EP, recursively iterate through the Graph and any nested subgraphs (recursion is bottom-up). 
// assign any nodes to the EP that are currently unassigned, and that the EP can handle. static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr, diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index fbe299f68f717..b5741f08513fe 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -204,43 +204,6 @@ common::Status LoadDynamicLibraryFromProvider(onnxruntime::PathString library_na } #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) -/* -Status ApplyConstantFoldingDQ(const Graph&, const ComputeCapability& this_optimization, ComputeCapability& cc_to_update) { - auto logger = const_cast(&logging::LoggingManager::DefaultLogger()); - return Status::OK(); -} - -std::vector> ConstantFoldingDQ(const GraphViewer& graph_viewer) { - std::vector> result; - std::unique_ptr sub_graph = std::make_unique(); - const std::vector& node_index = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED); - for (const auto& index : node_index) { - const auto& node = graph_viewer.GetNode(index); - if (node->OpType() != "DequantizeLinear") { - continue; - } - sub_graph->nodes.push_back(index); - std::cout << node->Name() << ", op type: " << node->OpType() << std::endl; - } - - result.push_back(std::make_unique(std::move(sub_graph))); - result.back()->optimization_func = ApplyConstantFoldingDQ; - return result; -} - -Status GetPredefinedEPGraphTransformersForLookUp(std::unordered_map>(const GraphViewer&)>>& map) { - static const std::string kEP_GRAPH_TRANSFORMER_CONSTANT_FOLDING_DQ = "ConstantFoldingDQ"; - static std::unordered_map>(const GraphViewer&)>> ep_transformers_map; - - if (ep_transformers_map.find(kEP_GRAPH_TRANSFORMER_CONSTANT_FOLDING_DQ) == ep_transformers_map.end()) { - ep_transformers_map[kEP_GRAPH_TRANSFORMER_CONSTANT_FOLDING_DQ] = ConstantFoldingDQ; - } - - map = ep_transformers_map; - 
return Status::OK(); -} -*/ - const GraphTransformerManager* graph_transformer_manager; #if defined(_MSC_VER) && !defined(__clang__) From 372342c82acad22416208702c3a98953c7f81ef3 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Tue, 28 Jan 2025 21:24:03 -0800 Subject: [PATCH 09/24] add a function to include DQ that is filtered out by TRT parser --- .../tensorrt/tensorrt_execution_provider.cc | 51 +++++++++++++++++-- .../tensorrt_execution_provider_helper.cc | 6 +-- 2 files changed, 50 insertions(+), 7 deletions(-) diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index c9499a10c3479..ef9a37125a17e 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -2555,8 +2555,8 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, } bool early_termination = false; - //supported_nodes_vector = GetSupportedList(parser_nodes_vector, 0, max_partition_iterations_, graph, &early_termination); - supported_nodes_vector = parser_nodes_vector; + supported_nodes_vector = GetSupportedList(parser_nodes_vector, 0, max_partition_iterations_, graph, &early_termination); + //supported_nodes_vector = parser_nodes_vector; if (early_termination) { supported_nodes_vector.clear(); } @@ -2660,7 +2660,7 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, /** * Enable EP related L2+ graph optimizations with steps: * - * 1. call provider bridge API to lookup pre-defined optimizer by name and get selection function + * 1. 
Call provider bridge API to lookup pre-defined optimizer by name and get selection function * - Run selection function to get selection ComputeCapability * - ComputeCapability.optimize_func would be set by the optimizer to the function that does the optimization * @@ -2679,8 +2679,51 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, std::unordered_set trt_selection_node_set; // The qualified dq nodes selected by TRT EP std::unordered_map consumer_to_dq; // consumer node -> dq node + // Note: The NodeIndex here is the node index in the graph, not the index in node vector in supported_nodes_vector. + SelectQualifiedDQNode(graph, trt_selection_node_set, consumer_to_dq); + // Include nodes that are filtered out by TRT parser. + auto update_supported_node_vector = [&](SubGraph_t& supported_node_vector, SubGraphCollection_t& supported_nodes_vector) -> void { + if (!consumer_to_dq.empty()) { + const std::vector& node_index = graph.GetNodesInTopologicalOrder(1); + for (auto index : supported_node_vector.first) { + if (consumer_to_dq.find(node_index[index]) == consumer_to_dq.end()) { + continue; + } + + auto dq_node_index = consumer_to_dq[node_index[index]]; + + // Check if DQ node is included in one of the subgraphs + auto in_the_subgraph_collection = [&](NodeIndex node_idx) -> bool { + for (auto& node_vector : supported_nodes_vector) { + if (!node_vector.second) { + continue; + } + for (auto index : node_vector.first) { + if (node_index[index] == node_idx) { + return true; + } + } + } + return false; + }; + if (in_the_subgraph_collection(dq_node_index)) { + continue; + } + // Find the iterator pointing to the target element + auto it = std::find(node_index.begin(), node_index.end(), dq_node_index); + if (it != node_index.end()) { + // Calculate the index + int idx = std::distance(node_index.begin(), it); + supported_node_vector.first.push_back(static_cast(idx)); + auto node = graph.GetNode(dq_node_index); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " << 
node->Name() << " is included which is filtered out by TRT parser."; + } + } + } + }; + // Create ComputeCapability int number_of_trt_nodes = 0, subgraph_index = 0; for (const auto& group : supported_nodes_vector) { @@ -2689,7 +2732,7 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, std::unique_ptr sub_graph = GetSubGraph(group, graph, model_hash, subgraph_index); auto compute_capability = ComputeCapability::Create(std::move(sub_graph)); - // add optimization compute capability to node_to_optimize + // add optimization ComputeCapability to node_to_optimize for (auto& cc : selection_cc) { std::unique_ptr optimization_cc = CreateOptimizationComputeCapability(cc.get(), trt_selection_node_set, compute_capability.get()); compute_capability->add_nodes_to_optimize(std::move(optimization_cc)); diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_helper.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_helper.cc index b819711679b48..726988bef1552 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_helper.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_helper.cc @@ -268,7 +268,7 @@ void TensorrtExecutionProvider::SetAllGraphInputs(Graph& graph) const { void TensorrtExecutionProvider::SelectQualifiedDQNode(const GraphViewer& graph, std::unordered_set& selection_node_set, std::unordered_map& consumer_to_dq) const { - LOGS_DEFAULT(VERBOSE) << "Select qualified DQ nodes ..."; + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Select qualified DQ nodes ..."; const std::vector& node_index = graph.GetNodesInTopologicalOrder(1 /*priority-based topological sort*/); for (auto index : node_index) { auto* node = graph.GetNode(index); @@ -293,10 +293,10 @@ void TensorrtExecutionProvider::SelectQualifiedDQNode(const GraphViewer& graph, const Node& consumer_node = *node->OutputNodesBegin(); selection_node_set.insert(index); consumer_to_dq[consumer_node.Index()] = index; - 
LOGS_DEFAULT(VERBOSE) << consumer_node.Name() << " <- " << node->Name(); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " << consumer_node.Name() << " < -" << node->Name(); } } - LOGS_DEFAULT(VERBOSE) << "Total " << selection_node_set.size() << " DequantizeLinear node(s) are selected."; + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Total " << selection_node_set.size() << " DequantizeLinear node(s) are selected."; } /** From 39fa8973b6bcf60d48fd5acd8f146e1fa4b1742a Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Mon, 3 Feb 2025 14:25:09 -0800 Subject: [PATCH 10/24] add standalone GraphOptimizerRegistry as singleton --- .../core/framework/execution_provider.h | 8 --- .../core/optimizer/graph_transformer_utils.h | 9 ++- .../core/framework/execution_provider.cc | 3 +- .../core/framework/graph_partitioner.cc | 32 +++------- .../core/framework/graph_partitioner.h | 5 +- .../optimizer/graph_optimizer_registry.cc | 62 +++++++++++++++++++ .../core/optimizer/graph_optimizer_registry.h | 47 ++++++++++++++ .../core/optimizer/graph_transformer_mgr.h | 1 - .../core/optimizer/graph_transformer_utils.cc | 19 +++++- .../providers/cuda/cuda_execution_provider.cc | 3 +- .../providers/cuda/cuda_execution_provider.h | 3 +- .../providers/shared_library/provider_api.h | 1 - .../provider_bridge_provider.cc | 5 +- .../shared_library/provider_interfaces.h | 8 +-- .../tensorrt/tensorrt_execution_provider.cc | 9 ++- .../tensorrt/tensorrt_execution_provider.h | 3 +- onnxruntime/core/session/inference_session.cc | 32 ++++++++++ onnxruntime/core/session/inference_session.h | 7 +++ .../core/session/provider_bridge_ort.cc | 18 +++--- .../test/framework/inference_session_test.cc | 3 +- .../test/framework/session_state_test.cc | 9 +-- .../internal_testing_execution_provider.cc | 3 +- .../internal_testing_execution_provider.h | 3 +- 23 files changed, 212 insertions(+), 81 deletions(-) create mode 100644 onnxruntime/core/optimizer/graph_optimizer_registry.cc create mode 100644 
onnxruntime/core/optimizer/graph_optimizer_registry.h diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h index 08c4ef2aea531..0d9e6db1a7748 100644 --- a/include/onnxruntime/core/framework/execution_provider.h +++ b/include/onnxruntime/core/framework/execution_provider.h @@ -20,7 +20,6 @@ struct ComputeCapability; class KernelRegistry; struct KernelCreateInfo; class Node; -class GraphTransformerManager; } // namespace onnxruntime #else #include @@ -129,16 +128,9 @@ class IExecutionProvider { For kernels registered in a kernel registry, `kernel_lookup` must be used to find a matching kernel for this EP. */ - /* virtual std::vector> GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& kernel_lookup) const; - */ - - virtual std::vector> - GetCapability(const onnxruntime::GraphViewer& graph_viewer, - const IKernelLookup& kernel_lookup, - const onnxruntime::GraphTransformerManager& graph_transformer_mgr) const; /** Get kernel registry per execution provider type. 
diff --git a/include/onnxruntime/core/optimizer/graph_transformer_utils.h b/include/onnxruntime/core/optimizer/graph_transformer_utils.h index e4a5b011dfeb6..aa2d42ec29698 100644 --- a/include/onnxruntime/core/optimizer/graph_transformer_utils.h +++ b/include/onnxruntime/core/optimizer/graph_transformer_utils.h @@ -59,11 +59,18 @@ InlinedVector> GenerateTransformers( std::unordered_map>* p_buffered_tensors = nullptr); InlinedVector> GenerateTransformersForEP( - TransformerLevel level, const SessionOptions& session_options, const IExecutionProvider& cpu_execution_provider, /*required by constant folding*/ const logging::Logger& logger); +/* +InlinedVector> GenerateTransformersForEP( + TransformerLevel level, + const SessionOptions& session_options, + const IExecutionProvider& cpu_execution_provider, + const logging::Logger& logger); +*/ + #endif // !defined(ORT_MINIMAL_BUILD) #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) diff --git a/onnxruntime/core/framework/execution_provider.cc b/onnxruntime/core/framework/execution_provider.cc index cae3a37e4146a..b39924d4c3ff9 100644 --- a/onnxruntime/core/framework/execution_provider.cc +++ b/onnxruntime/core/framework/execution_provider.cc @@ -13,8 +13,7 @@ namespace onnxruntime { std::vector> IExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, - const IKernelLookup& kernel_lookup, - const onnxruntime::GraphTransformerManager& graph_transformer_mgr) const { + const IKernelLookup& kernel_lookup) const { std::vector> result; for (const auto& node : graph.Nodes()) { if (const KernelCreateInfo* kernel_create_info = kernel_lookup.LookUpKernel(node); diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index 41a0ce49061d1..4cd28d18b765a 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -12,6 +12,7 @@ #include "core/framework/kernel_lookup.h" #include 
"core/framework/kernel_registry_manager.h" #include "core/framework/kernel_registry.h" +#include "core/optimizer/graph_optimizer_registry.h" #include "core/graph/function.h" #include "core/graph/function_utils.h" #include "core/graph/graph_viewer.h" @@ -58,7 +59,6 @@ struct PartitionParams { std::reference_wrapper fused_node_unique_id; std::reference_wrapper transform_layout_function; std::reference_wrapper debug_graph_fn; - std::reference_wrapper graph_transformer_manager; #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) }; } // namespace @@ -131,16 +131,14 @@ struct GetCapabilityForEPParams { GraphPartitioner::Mode mode; std::reference_wrapper transform_layout; std::reference_wrapper debug_graph_fn; - std::reference_wrapper graph_transformer_manager; #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) }; auto get_capabilities = [](const IExecutionProvider& ep, const GraphViewer& graph_viewer, - const IExecutionProvider::IKernelLookup& kernel_lookup, - const onnxruntime::GraphTransformerManager& graph_transformer_manager) { + const IExecutionProvider::IKernelLookup& kernel_lookup) { std::vector> capabilities; - capabilities = ep.GetCapability(graph_viewer, kernel_lookup, graph_transformer_manager); + capabilities = ep.GetCapability(graph_viewer, kernel_lookup); // In theory an EP could return an empty capability. Remove those. 
capabilities.erase(std::remove_if(capabilities.begin(), capabilities.end(), @@ -174,11 +172,10 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params, const l auto& graph = params.graph.get(); auto& capabilities = params.capabilities.get(); - auto& graph_transformer_manager = params.graph_transformer_manager.get(); { const GraphViewer graph_viewer(graph); - capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup, graph_transformer_manager); + capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup); if (capabilities.empty()) { return Status::OK(); @@ -216,7 +213,7 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params, const l capabilities.clear(); const GraphViewer graph_viewer(graph); - capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup, graph_transformer_manager); + capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup); // all nodes with an index >= first_new_node with domain of kMSInternalNHWCDomain should be in the capabilities InlinedHashSet new_nodes_in_capabilities; @@ -253,7 +250,6 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params, const l // It also does not perform layout transformation. This will be done during normal partitioning. static Status GetCapabilityForEPForAotInlining(const GraphViewer& graph_viewer, const KernelRegistryManager& kernel_registry_mgr, - const GraphTransformerManager& graph_transformer_mgr, const IExecutionProvider& current_ep, const logging::Logger& logger, std::vector>& capabilities) { @@ -264,9 +260,10 @@ static Status GetCapabilityForEPForAotInlining(const GraphViewer& graph_viewer, kernel_registries_for_ep, kernel_registry_mgr.GetKernelTypeStrResolver(), logger}; + auto graph_optimizer_registry = onnxruntime::GraphOptimizerRegistry::Get(); // TODO: Provide EP with a capability to look inside the functions. 
- capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup, graph_transformer_mgr); + capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup); return Status::OK(); } @@ -373,7 +370,6 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr, int& fused_node_unique_id, const layout_transformation::TransformLayoutFunction& transform_layout_fn, const layout_transformation::DebugGraphFn& debug_graph_fn, - const onnxruntime::GraphTransformerManager& graph_transfomer_manager, const logging::Logger& logger) { // handle testing edge case where optimizers or constant lifting results in graph with no nodes. // doing it here saves all providers checking for this in GetCapability @@ -388,7 +384,7 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr, // we pass through the FuncManager from the top level graph ORT_RETURN_IF_ERROR(PartitionOnnxFormatModelImpl(*subgraph, func_mgr, kernel_registry_mgr, fused_kernel_registry, current_ep, mode, fused_node_unique_id, - transform_layout_fn, debug_graph_fn, graph_transfomer_manager, logger)); + transform_layout_fn, debug_graph_fn, logger)); } } @@ -411,8 +407,7 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr, std::ref(capabilities), mode, std::cref(transform_layout_fn), - std::cref(debug_graph_fn), - std::cref(graph_transfomer_manager)}; + std::cref(debug_graph_fn)}; ORT_RETURN_IF_ERROR(GetCapabilityForEP(get_capability_params, logger)); if (capabilities.empty()) { @@ -587,7 +582,6 @@ static Status InlineNodes(Graph& graph, bool& modified_graph) { static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_providers, const KernelRegistryManager& kernel_registry_mgr, - const GraphTransformerManager& graph_transformer_mgr, Graph& graph, const logging::Logger& logger, InlinedHashSet& not_inlined, @@ -604,7 +598,6 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide // we pass 
through the FuncManager from the top level graph ORT_RETURN_IF_ERROR(InlineFunctionsAOTImpl(execution_providers, kernel_registry_mgr, - graph_transformer_mgr, *subgraph, logger, not_inlined, @@ -630,7 +623,7 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide InlinedHashSet claimed_by_ep; for (const auto& ep : execution_providers) { std::vector> capabilities; - ORT_RETURN_IF_ERROR(GetCapabilityForEPForAotInlining(graph_viewer, kernel_registry_mgr, graph_transformer_mgr, * ep, logger, + ORT_RETURN_IF_ERROR(GetCapabilityForEPForAotInlining(graph_viewer, kernel_registry_mgr, * ep, logger, capabilities)); for (auto& capability : capabilities) { const auto& nodes = capability->sub_graph->nodes; @@ -770,7 +763,6 @@ static Status PartitionOnnxFormatModel(const PartitionParams& partition_params, auto& fused_kernel_registry = partition_params.fused_kernel_registry.get(); auto& fused_node_unique_id = partition_params.fused_node_unique_id.get(); const auto& transform_layout_function = partition_params.transform_layout_function; - const auto& graph_transformer_manager = partition_params.graph_transformer_manager; do { // process full graph with each EP @@ -779,7 +771,6 @@ static Status PartitionOnnxFormatModel(const PartitionParams& partition_params, fused_kernel_registry, *ep, mode, fused_node_unique_id, transform_layout_function, partition_params.debug_graph_fn, - graph_transformer_manager, logger)); } @@ -831,7 +822,6 @@ static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_param GraphPartitioner::Mode::kOrtFormatLoad, std::cref(partition_params.transform_layout_function), std::cref(partition_params.debug_graph_fn), - std::cref(partition_params.graph_transformer_manager), #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) }; // clang-format on @@ -947,7 +937,6 @@ Status GraphPartitioner::InlineFunctionsAOT(Model& model, size_t inlined_count = 0; 
ORT_RETURN_IF_ERROR(InlineFunctionsAOTImpl(execution_providers, kernel_registry_manager, - graph_transformer_mgr_, graph, logger, not_inlined, @@ -1006,7 +995,6 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, std::ref(fused_node_unique_id), std::cref(transform_layout_function), std::cref(debug_graph_fn), - std::cref(graph_transformer_mgr_), }; #else // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) diff --git a/onnxruntime/core/framework/graph_partitioner.h b/onnxruntime/core/framework/graph_partitioner.h index 8fcc1de782a6a..d1ef193cf1520 100644 --- a/onnxruntime/core/framework/graph_partitioner.h +++ b/onnxruntime/core/framework/graph_partitioner.h @@ -7,7 +7,6 @@ #include "core/graph/graph.h" #include "core/framework/fuse_nodes_funcs.h" #include "core/framework/transform_layout_functions.h" -#include "core/optimizer/graph_transformer_mgr.h" namespace onnxruntime { @@ -25,9 +24,8 @@ class GraphPartitioner { }; // The order of providers represents the user preference. - GraphPartitioner(KernelRegistryManager& kernel_registry_mgr, const GraphTransformerManager& graph_transformer_mgr, const ExecutionProviders& providers) + GraphPartitioner(KernelRegistryManager& kernel_registry_mgr, const ExecutionProviders& providers) : kernel_registry_mgr_(kernel_registry_mgr), - graph_transformer_mgr_(graph_transformer_mgr), providers_(providers) { } @@ -66,7 +64,6 @@ class GraphPartitioner { KernelRegistryManager& kernel_registry_mgr_; const ExecutionProviders& providers_; - const GraphTransformerManager& graph_transformer_mgr_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/graph_optimizer_registry.cc b/onnxruntime/core/optimizer/graph_optimizer_registry.cc new file mode 100644 index 0000000000000..2811acdeeaeed --- /dev/null +++ b/onnxruntime/core/optimizer/graph_optimizer_registry.cc @@ -0,0 +1,62 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/optimizer/graph_optimizer_registry.h" +#include "core/optimizer/graph_transformer_utils.h" + +using namespace onnxruntime; +using namespace ::onnxruntime::common; + +namespace onnxruntime { + +common::Status GraphOptimizerRegistry::Register(std::unique_ptr transformer) { + const auto& name = transformer->Name(); + if (name_to_transformer_map_.find(name) != name_to_transformer_map_.end()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "This transformer is already registered " + name); + } + + name_to_transformer_map_[name] = transformer.get(); + transformer_list_.push_back(std::move(transformer)); + return Status::OK(); +} + +GraphTransformer* GraphOptimizerRegistry::GetTransformerByName(std::string& name) const { + if (name_to_transformer_map_.find(name) != name_to_transformer_map_.end()) { + return name_to_transformer_map_.at(name); + } + return nullptr; +} + +// Registers all the predefined transformers for EP +common::Status GraphOptimizerRegistry::AddPredefinedOptimizers( + const onnxruntime::SessionOptions& sess_options, + const onnxruntime::IExecutionProvider& cpu_ep, + const logging::Logger& logger) { + // TODO: Apply optimization level here if we later decide to do so + auto transformers_to_register = [&]() { + return optimizer_utils::GenerateTransformersForEP(sess_options, cpu_ep, logger); + }(); + + for (auto& entry : transformers_to_register) { + ORT_RETURN_IF_ERROR(Get()->Register(std::move(entry))); + } + return Status::OK(); +} + +common::Status GraphOptimizerRegistry::ApplyTransformer(Graph& graph, std::string& name, + const logging::Logger& logger) const { + auto transformer = GetTransformerByName(name); + if (!transformer) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "This transformer is not registered " + name); + } + + bool modified = false; + ORT_RETURN_IF_ERROR(transformer->Apply(graph, modified, logger)); + + return Status::OK(); +} + +// Initialize static members +std::shared_ptr 
onnxruntime::GraphOptimizerRegistry::graph_optimizer_registry = nullptr; +std::mutex GraphOptimizerRegistry::registry_mutex; +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/graph_optimizer_registry.h b/onnxruntime/core/optimizer/graph_optimizer_registry.h new file mode 100644 index 0000000000000..fc425b50de44d --- /dev/null +++ b/onnxruntime/core/optimizer/graph_optimizer_registry.h @@ -0,0 +1,47 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/inlined_containers.h" +#include "core/common/logging/logging.h" +//#include "core/common/common.h" +#include "core/optimizer/graph_transformer.h" +#include "core/framework/execution_providers.h" + +namespace onnxruntime { +class GraphOptimizerRegistry { + public: + explicit GraphOptimizerRegistry() {} + GraphOptimizerRegistry(const GraphOptimizerRegistry&) = delete; + + static std::shared_ptr Get() { + if (!graph_optimizer_registry) { // First Check (without locking) + std::lock_guard lock(registry_mutex); + if (!graph_optimizer_registry) { // Second Check (with locking) + graph_optimizer_registry = std::make_shared(); + } + } + return graph_optimizer_registry; + } + + common::Status AddPredefinedOptimizers(const onnxruntime::SessionOptions& sess_options, + const onnxruntime::IExecutionProvider& cpu_ep, + const logging::Logger& logger); + + common::Status ApplyTransformer(Graph& graph, std::string& name, + const logging::Logger& logger) const; + + common::Status Register(std::unique_ptr transformer); + + // Get transformer by name. Return nullptr if not found. 
+ GraphTransformer* GetTransformerByName(std::string& name) const; + + private: + InlinedVector> transformer_list_; + InlinedHashMap name_to_transformer_map_; + + static std::shared_ptr graph_optimizer_registry; + static std::mutex registry_mutex; +}; +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/graph_transformer_mgr.h b/onnxruntime/core/optimizer/graph_transformer_mgr.h index f2a6b77f5fc39..06fb341b8df43 100644 --- a/onnxruntime/core/optimizer/graph_transformer_mgr.h +++ b/onnxruntime/core/optimizer/graph_transformer_mgr.h @@ -10,7 +10,6 @@ #include "core/optimizer/rewrite_rule.h" namespace onnxruntime { - // Manages a list of graph transformers. It is initialized with a list of graph // transformers. Each inference session can further register additional ones. class GraphTransformerManager { diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 854d8847ea7c6..899f062f055bf 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -187,10 +187,24 @@ std::unique_ptr GenerateRuleBasedGraphTransformer( return rule_transformer; } +InlinedVector> GenerateTransformersForEP( + const SessionOptions& session_options, + const IExecutionProvider& cpu_execution_provider, /*required by constant folding DQ*/ + const logging::Logger& logger) { + InlinedVector> transformers; + + // TODO: Apply optimization level here if we later decide to do so + const InlinedHashSet node_index_set = {}; + transformers.emplace_back(std::make_unique(cpu_execution_provider, false /*skip_dequantize_linear*/, + session_options.config_options, node_index_set)); + return transformers; +} + +/* InlinedVector> GenerateTransformersForEP( TransformerLevel level, const SessionOptions& session_options, - const IExecutionProvider& cpu_execution_provider, /*required by constant folding*/ + const IExecutionProvider& cpu_execution_provider, 
const logging::Logger& logger) { InlinedVector> transformers; switch (level) { @@ -199,7 +213,7 @@ InlinedVector> GenerateTransformersForEP( } case TransformerLevel::Level2: { const InlinedHashSet node_index_set = {}; - transformers.emplace_back(std::make_unique(cpu_execution_provider, false /*skip_dequantize_linear*/, + transformers.emplace_back(std::make_unique(cpu_execution_provider, false, session_options.config_options, node_index_set)); break; } @@ -211,6 +225,7 @@ InlinedVector> GenerateTransformersForEP( } return transformers; } +*/ InlinedVector> GenerateTransformers( TransformerLevel level, diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 5b923bcd31b49..4a10de153653c 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -2658,8 +2658,7 @@ std::unique_ptr CUDAExecutionProvider::GetDataTransf std::vector> CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, - const IKernelLookup& kernel_lookup, - const GraphTransformerManager& graph_transformer_mgr) const { + const IKernelLookup& kernel_lookup) const { InlinedVector candidates; // A subset of the above vector. A subset of the tentative_nodes might be moved to CPU. 
InlinedVector tentative_nodes; diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.h b/onnxruntime/core/providers/cuda/cuda_execution_provider.h index b9f06e136ad17..bd2be2eac2181 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.h +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.h @@ -72,8 +72,7 @@ class CUDAExecutionProvider : public IExecutionProvider { std::vector> GetCapability( const onnxruntime::GraphViewer& graph, - const IKernelLookup& kernel_lookup, - const GraphTransformerManager& graph_transformer_mgr) const override; + const IKernelLookup& kernel_lookup) const override; int GetDeviceId() const override { return info_.device_id; } const cudaDeviceProp& GetDeviceProp() const { return device_prop_; }; diff --git a/onnxruntime/core/providers/shared_library/provider_api.h b/onnxruntime/core/providers/shared_library/provider_api.h index 3d1f78e14fcdc..45f81ed22b7f7 100644 --- a/onnxruntime/core/providers/shared_library/provider_api.h +++ b/onnxruntime/core/providers/shared_library/provider_api.h @@ -175,7 +175,6 @@ struct SparseTensor; class TensorSeq; class SessionState; class ModelMetadefIdGenerator; -class GraphTransformerManager; class If; class Loop; diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc index 380be5bef7904..aa8c367d25d51 100644 --- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc +++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc @@ -331,9 +331,8 @@ bool IAllocator::CalcMemSizeForArrayWithAlignment(size_t nmemb, size_t size, siz } std::vector> IExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, - const IKernelLookup& kernel_lookup, - const onnxruntime::GraphTransformerManager& graph_transformer_mgr) const { - return g_host->IExecutionProvider__GetCapability(this, graph_viewer, kernel_lookup, graph_transformer_mgr); 
+ const IKernelLookup& kernel_lookup) const { + return g_host->IExecutionProvider__GetCapability(this, graph_viewer, kernel_lookup); } common::Status IExecutionProvider::Compile(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) { diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index 76c650276642e..16fc15ea76725 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -142,9 +142,8 @@ struct Node__EdgeIterator { struct ProviderHost { virtual const OrtApiBase* OrtGetApiBase() = 0; - virtual Status GetEPOptimizerByName(const std::string& name, - const GraphTransformerManager& graph_transformer_mgr, - std::function>(const GraphViewer&)>& selection_func) = 0; + virtual Status GetOptimizerByName(const std::string& name, + std::function>(const GraphViewer&)>& selection_func) = 0; virtual void* HeapAllocate(size_t size) = 0; virtual void HeapFree(void*) = 0; @@ -250,8 +249,7 @@ struct ProviderHost { // IExecutionProvider virtual std::vector> IExecutionProvider__GetCapability(const IExecutionProvider* p, const onnxruntime::GraphViewer& graph_viewer, - const IExecutionProvider::IKernelLookup& kernel_lookup, - const onnxruntime::GraphTransformerManager& graph_transformer_mgr) = 0; + const IExecutionProvider::IKernelLookup& kernel_lookup) = 0; virtual common::Status IExecutionProvider__Compile(IExecutionProvider* p, const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) = 0; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index ef9a37125a17e..8a899884e5892 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -2451,8 +2451,7 @@ bool 
TensorrtExecutionProvider::DetectTensorRTGraphCycles(SubGraphCollection_t& std::vector> TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, - const IKernelLookup&, /*kernel_lookup*/ - const GraphTransformerManager& graph_transformer_mgr) const { + const IKernelLookup& /*kernel_lookup*/) const { // Construct subgraph capability from node list std::vector> result; // Get ModelPath @@ -2555,8 +2554,8 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, } bool early_termination = false; - supported_nodes_vector = GetSupportedList(parser_nodes_vector, 0, max_partition_iterations_, graph, &early_termination); - //supported_nodes_vector = parser_nodes_vector; + //supported_nodes_vector = GetSupportedList(parser_nodes_vector, 0, max_partition_iterations_, graph, &early_termination); + supported_nodes_vector = parser_nodes_vector; if (early_termination) { supported_nodes_vector.clear(); } @@ -2671,7 +2670,7 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, */ std::function>(const GraphViewer&)> selection_func; - auto status = g_host->GetEPOptimizerByName("ConstantFoldingDQ", graph_transformer_mgr, selection_func); + auto status = g_host->GetOptimizerByName("ConstantFoldingDQ", selection_func); std::vector> selection_cc; if (selection_func) { selection_cc = selection_func(graph); diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index 2908c3061f47c..4946058e16d75 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -247,8 +247,7 @@ class TensorrtExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const GraphViewer& graph, - const IKernelLookup&, /*kernel_lookup*/ - const GraphTransformerManager& graph_transformer_mgr) const override; + const IKernelLookup& /*kernel_lookup*/) const override; int 
GetDeviceId() const { return device_id_; } diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index f1020baf7114e..1b43a16d3a774 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -41,6 +41,7 @@ #include "core/graph/model_saving_options.h" #include "core/optimizer/graph_transformer_utils.h" #include "core/optimizer/graph_transformer.h" +#include "core/optimizer/graph_optimizer_registry.h" #include "core/optimizer/layout_transformation/layout_transformation.h" #include "core/optimizer/insert_cast_transformer.h" #include "core/optimizer/qdq_transformer/ensure_unique_dq_for_node_unit.h" @@ -1850,10 +1851,22 @@ common::Status InferenceSession::Initialize() { record_runtime_optimization_produced_op_schema, *session_logger_)); + // register predefined transformers for EP + auto graph_optimizer_registry = onnxruntime::GraphOptimizerRegistry::Get(); + graph_optimizer_registry->AddPredefinedOptimizers(session_options_, + *execution_providers_. 
Get(onnxruntime::kCpuExecutionProvider), + *session_logger_); + /* + ORT_RETURN_IF_ERROR_SESSIONID_(RegisterPredefinedOptimizersForEP(*graph_optimizer_registry, + *session_logger_)); + */ + + /* // add predefined transformers for EP ORT_RETURN_IF_ERROR_SESSIONID_(AddPredefinedTransformersForEP(ep_graph_transformer_mgr_, session_options_.graph_optimization_level, *session_logger_)); + */ #ifdef USE_DML const IExecutionProvider* dmlExecutionProvider = execution_providers_.Get(kDmlExecutionProvider); @@ -3285,6 +3298,24 @@ common::Status InferenceSession::AddPredefinedTransformers( return Status::OK(); } +// Registers all the predefined transformers for EP +common::Status InferenceSession::RegisterPredefinedOptimizersForEP( + GraphOptimizerRegistry& optimizer_registry, + const logging::Logger& logger) const { + const auto& cpu_ep = *execution_providers_.Get(onnxruntime::kCpuExecutionProvider); + + // TODO: Apply optimization level here if we later decide to do so + auto transformers_to_register = [&]() { + return optimizer_utils::GenerateTransformersForEP(session_options_, cpu_ep, logger); + }(); + + for (auto& entry : transformers_to_register) { + ORT_RETURN_IF_ERROR(optimizer_registry.Register(std::move(entry))); + } + return Status::OK(); +} + +/* // Registers all the predefined transformers for EP with transformer manager common::Status InferenceSession::AddPredefinedTransformersForEP( GraphTransformerManager& transformer_manager, @@ -3306,6 +3337,7 @@ common::Status InferenceSession::AddPredefinedTransformersForEP( } return Status::OK(); } +*/ #endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index af5f7384a5d79..ec31c9a22a6aa 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -29,6 +29,7 @@ #include "core/optimizer/graph_transformer_level.h" #include "core/optimizer/graph_transformer_mgr.h" #include 
"core/optimizer/insert_cast_transformer.h" +#include "core/optimizer/graph_optimizer_registry.h" #include #ifdef ENABLE_LANGUAGE_INTEROP_OPS #include "core/language_interop_ops/language_interop_ops.h" @@ -712,10 +713,16 @@ class InferenceSession { RecordRuntimeOptimizationProducedNodeOpSchemaFn record_runtime_optimization_produced_op_schema_fn, const logging::Logger& logger) const; + virtual common::Status RegisterPredefinedOptimizersForEP( + GraphOptimizerRegistry& optimizer_registry, + const logging::Logger& logger) const; + + /* virtual common::Status AddPredefinedTransformersForEP( GraphTransformerManager& transformer_manager, TransformerLevel graph_optimization_level, const logging::Logger& logger) const; + */ common::Status TransformGraph(onnxruntime::Graph& graph, bool saving_model_in_ort_format); diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index b5741f08513fe..6acc0b4f9b4c1 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -216,10 +216,9 @@ const GraphTransformerManager* graph_transformer_manager; struct ProviderHostImpl : ProviderHost { const OrtApiBase* OrtGetApiBase() override { return ::OrtGetApiBase(); } - Status GetEPOptimizerByName(const std::string& name, - const GraphTransformerManager& transformer_mgr, - std::function>(const GraphViewer&)>& selection_func) override { - static const GraphTransformerManager& graph_transformer_mgr = transformer_mgr; + Status GetOptimizerByName(const std::string& name, + std::function>(const GraphViewer&)>& selection_func) override { + //static const GraphTransformerManager& graph_transformer_mgr = transformer_mgr; std::string optimizer_name(name); // Pre-defined graph transformers/optimizers @@ -249,11 +248,13 @@ struct ProviderHostImpl : ProviderHost { dq_node_index_set.insert(index); } - ConstantFoldingDQ* transformer = 
static_cast(graph_transformer_mgr.GetTransformerByName(optimizer_name)); + auto optimizer_registry = onnxruntime::GraphOptimizerRegistry::Get(); + + ConstantFoldingDQ* transformer = static_cast(optimizer_registry->GetTransformerByName(optimizer_name)); transformer->UpdateNodeIndexSet(dq_node_index_set); // apply constant folding on DQ nodes - graph_transformer_mgr.ApplyTransformer(graph, optimizer_name, *logger); + optimizer_registry->ApplyTransformer(graph, optimizer_name, *logger); // update the overall ComputeCapability std::vector updated_nodes; @@ -455,9 +456,8 @@ struct ProviderHostImpl : ProviderHost { // IExecutionProvider (direct) std::vector> IExecutionProvider__GetCapability( const IExecutionProvider* p, const onnxruntime::GraphViewer& graph_viewer, - const IExecutionProvider::IKernelLookup& kernel_lookup, - const onnxruntime::GraphTransformerManager& graph_transformer_mgr) override { - return p->IExecutionProvider::GetCapability(graph_viewer, kernel_lookup, graph_transformer_mgr); + const IExecutionProvider::IKernelLookup& kernel_lookup) override { + return p->IExecutionProvider::GetCapability(graph_viewer, kernel_lookup); } common::Status IExecutionProvider__Compile(IExecutionProvider* p, const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) override { diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc index a94520077774f..740c566794f15 100644 --- a/onnxruntime/test/framework/inference_session_test.cc +++ b/onnxruntime/test/framework/inference_session_test.cc @@ -137,8 +137,7 @@ class FuseExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const onnxruntime::GraphViewer& graph, - const IKernelLookup& /*kernel_lookup*/, - const onnxruntime::GraphTransformerManager& graph_transformer_mgr) const override { + const IKernelLookup& /*kernel_lookup*/) const override { // Fuse two add into one. 
std::vector> result; std::unique_ptr sub_graph = std::make_unique(); diff --git a/onnxruntime/test/framework/session_state_test.cc b/onnxruntime/test/framework/session_state_test.cc index a9760e62f217a..eb7553c07eeba 100644 --- a/onnxruntime/test/framework/session_state_test.cc +++ b/onnxruntime/test/framework/session_state_test.cc @@ -261,8 +261,7 @@ TEST_P(SessionStateTestP, TestInitializerProcessing) { SessionState session_state(graph, execution_providers, tp.get(), nullptr, dtm, edlm, DefaultLoggingManager().DefaultLogger(), profiler, sess_options); - GraphTransformerManager graph_transformer_mgr(10); - GraphPartitioner partitioner(krm, graph_transformer_mgr, execution_providers); + GraphPartitioner partitioner(krm, execution_providers); ASSERT_STATUS_OK( partitioner.Partition( graph, session_state.GetMutableFuncMgr(), @@ -347,9 +346,8 @@ TEST(SessionStateTest, TestInitializerMemoryAllocatedUsingNonArenaMemory) { SessionState session_state(graph, execution_providers, nullptr, nullptr, dtm, edlm, DefaultLoggingManager().DefaultLogger(), profiler, sess_options); - GraphTransformerManager graph_transformer_mgr(10); // Partition the graph - GraphPartitioner partitioner(krm, graph_transformer_mgr, execution_providers); + GraphPartitioner partitioner(krm, execution_providers); ASSERT_STATUS_OK(partitioner.Partition( graph, session_state.GetMutableFuncMgr(), [&cpu_allocator](Graph& graph, bool& modified, const IExecutionProvider& execution_provider, @@ -406,9 +404,8 @@ TEST(SessionStateTest, TestInitializerMemoryAllocatedUsingNonArenaMemory) { SessionState session_state(graph, execution_providers, nullptr, nullptr, dtm, edlm, DefaultLoggingManager().DefaultLogger(), profiler, sess_options); - GraphTransformerManager graph_transformer_mgr(10); // Partition the graph - GraphPartitioner partitioner(krm, graph_transformer_mgr, execution_providers); + GraphPartitioner partitioner(krm, execution_providers); ASSERT_STATUS_OK(partitioner.Partition( graph, 
session_state.GetMutableFuncMgr(), [&cpu_allocator](Graph& graph, bool& modified, diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc index 31049d2ab7e3c..2e073def5d643 100644 --- a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc +++ b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc @@ -110,8 +110,7 @@ DataLayout InternalTestingExecutionProvider::GetPreferredLayout() const { std::vector> InternalTestingExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, - const IKernelLookup& kernel_lookup, - const GraphTransformerManager& graph_transformer_mgr) const { + const IKernelLookup& kernel_lookup) const { // find nodes that have ops in our supported list std::unordered_set supported_static_nodes; std::unordered_set supported_compiled_nodes; diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h index 3a193224d6309..6615eb82f2b05 100644 --- a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h +++ b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h @@ -19,8 +19,7 @@ class InternalTestingExecutionProvider : public IExecutionProvider { std::vector> GetCapability(const onnxruntime::GraphViewer& graph_view, - const IKernelLookup& /*kernel_lookup*/, - const onnxruntime::GraphTransformerManager& graph_transformer_mgr) const override; + const IKernelLookup& /*kernel_lookup*/) const override; common::Status Compile(const std::vector& fused_nodes, std::vector& node_compute_funcs) override; From 627a00ae11ab114eb51206886ec24e3eb2992cb7 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Mon, 3 Feb 2025 14:42:30 -0800 Subject: [PATCH 11/24] remove redundant code --- 
.../core/optimizer/graph_transformer_utils.h | 13 ----- .../core/framework/graph_partitioner.cc | 3 +- onnxruntime/core/session/inference_session.cc | 53 ++----------------- onnxruntime/core/session/inference_session.h | 12 ----- .../core/session/provider_bridge_ort.cc | 2 - 5 files changed, 4 insertions(+), 79 deletions(-) diff --git a/include/onnxruntime/core/optimizer/graph_transformer_utils.h b/include/onnxruntime/core/optimizer/graph_transformer_utils.h index aa2d42ec29698..31b0f22340510 100644 --- a/include/onnxruntime/core/optimizer/graph_transformer_utils.h +++ b/include/onnxruntime/core/optimizer/graph_transformer_utils.h @@ -58,19 +58,6 @@ InlinedVector> GenerateTransformers( concurrency::ThreadPool* intra_op_thread_pool = nullptr, std::unordered_map>* p_buffered_tensors = nullptr); -InlinedVector> GenerateTransformersForEP( - const SessionOptions& session_options, - const IExecutionProvider& cpu_execution_provider, /*required by constant folding*/ - const logging::Logger& logger); - -/* -InlinedVector> GenerateTransformersForEP( - TransformerLevel level, - const SessionOptions& session_options, - const IExecutionProvider& cpu_execution_provider, - const logging::Logger& logger); -*/ - #endif // !defined(ORT_MINIMAL_BUILD) #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index 4cd28d18b765a..2fb639a1d118a 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -137,8 +137,7 @@ struct GetCapabilityForEPParams { auto get_capabilities = [](const IExecutionProvider& ep, const GraphViewer& graph_viewer, const IExecutionProvider::IKernelLookup& kernel_lookup) { - std::vector> capabilities; - capabilities = ep.GetCapability(graph_viewer, kernel_lookup); + auto capabilities = ep.GetCapability(graph_viewer, kernel_lookup); // In theory an EP could return an empty 
capability. Remove those. capabilities.erase(std::remove_if(capabilities.begin(), capabilities.end(), diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 1b43a16d3a774..121b64a951bb9 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -647,7 +647,6 @@ InferenceSession::InferenceSession(const SessionOptions& session_options, const : #if !defined(ORT_MINIMAL_BUILD) graph_transformer_mgr_(session_options.max_num_graph_transformation_steps), - ep_graph_transformer_mgr_(session_options.max_num_graph_transformation_steps), #endif environment_(session_env) { // Initialize assets of this session instance @@ -661,7 +660,6 @@ InferenceSession::InferenceSession(const SessionOptions& session_options, : #if !defined(ORT_MINIMAL_BUILD) graph_transformer_mgr_(session_options.max_num_graph_transformation_steps), - ep_graph_transformer_mgr_(session_options.max_num_graph_transformation_steps), #endif external_intra_op_thread_pool_(external_intra_op_thread_pool), external_inter_op_thread_pool_(external_inter_op_thread_pool), @@ -675,7 +673,6 @@ InferenceSession::InferenceSession(const SessionOptions& session_options, const const PathString& model_uri) : model_location_(model_uri), graph_transformer_mgr_(session_options.max_num_graph_transformation_steps), - ep_graph_transformer_mgr_(session_options.max_num_graph_transformation_steps), environment_(session_env) { auto status = Model::Load(model_location_, model_proto_); ORT_ENFORCE(status.IsOK(), "Given model could not be parsed while creating inference session. 
Error message: ", @@ -696,7 +693,6 @@ InferenceSession::InferenceSession(const SessionOptions& session_options, InferenceSession::InferenceSession(const SessionOptions& session_options, const Environment& session_env, std::istream& model_istream) : graph_transformer_mgr_(session_options.max_num_graph_transformation_steps), - ep_graph_transformer_mgr_(session_options.max_num_graph_transformation_steps), environment_(session_env) { Status st = Model::Load(model_istream, &model_proto_); ORT_ENFORCE(st.IsOK(), "Could not parse model successfully while constructing the inference session"); @@ -708,7 +704,6 @@ InferenceSession::InferenceSession(const SessionOptions& session_options, const InferenceSession::InferenceSession(const SessionOptions& session_options, const Environment& session_env, const void* model_data, int model_data_len) : graph_transformer_mgr_(session_options.max_num_graph_transformation_steps), - ep_graph_transformer_mgr_(session_options.max_num_graph_transformation_steps), environment_(session_env) { const bool result = model_proto_.ParseFromArray(model_data, model_data_len); ORT_ENFORCE(result, "Could not parse model successfully while constructing the inference session"); @@ -1213,7 +1208,7 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool // 7. insert copy nodes (required transformer). 
// Run Ahead Of time function inlining - GraphPartitioner partitioner(kernel_registry_manager_, ep_graph_transformer_mgr_, execution_providers_); + GraphPartitioner partitioner(kernel_registry_manager_, execution_providers_); if (const bool disable_aot_function_inlining = session_options_.config_options.GetConfigOrDefault( kOrtSessionOptionsDisableAheadOfTimeFunctionInlining, "0") == "1"; @@ -1604,7 +1599,6 @@ namespace { Status PartitionOrtFormatModel(onnxruntime::Graph& graph, const ExecutionProviders& providers, KernelRegistryManager& kernel_registry_manager, - const onnxruntime::GraphTransformerManager& graph_transformer_manager, SessionState& session_state, const ConfigOptions& config_options, const logging::Logger& logger) { @@ -1624,7 +1618,7 @@ Status PartitionOrtFormatModel(onnxruntime::Graph& graph, } #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) - GraphPartitioner partitioner(kernel_registry_manager, graph_transformer_manager, providers); + GraphPartitioner partitioner(kernel_registry_manager, providers); ORT_RETURN_IF_ERROR(partitioner.Partition(graph, session_state.GetMutableFuncMgr(), transform_layout_fn, @@ -2076,7 +2070,7 @@ common::Status InferenceSession::Initialize() { "Loading anything other than ORT format models is not enabled in this build.")); #endif // !defined(ORT_MINIMAL_BUILD) } else { - ORT_RETURN_IF_ERROR_SESSIONID_(PartitionOrtFormatModel(graph, execution_providers_, kernel_registry_manager_, graph_transformer_mgr_, + ORT_RETURN_IF_ERROR_SESSIONID_(PartitionOrtFormatModel(graph, execution_providers_, kernel_registry_manager_, *session_state_, session_options_.config_options, *session_logger_)); #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) @@ -3298,47 +3292,6 @@ common::Status InferenceSession::AddPredefinedTransformers( return Status::OK(); } -// Registers all the predefined transformers for EP -common::Status InferenceSession::RegisterPredefinedOptimizersForEP( - 
GraphOptimizerRegistry& optimizer_registry, - const logging::Logger& logger) const { - const auto& cpu_ep = *execution_providers_.Get(onnxruntime::kCpuExecutionProvider); - - // TODO: Apply optimization level here if we later decide to do so - auto transformers_to_register = [&]() { - return optimizer_utils::GenerateTransformersForEP(session_options_, cpu_ep, logger); - }(); - - for (auto& entry : transformers_to_register) { - ORT_RETURN_IF_ERROR(optimizer_registry.Register(std::move(entry))); - } - return Status::OK(); -} - -/* -// Registers all the predefined transformers for EP with transformer manager -common::Status InferenceSession::AddPredefinedTransformersForEP( - GraphTransformerManager& transformer_manager, - TransformerLevel graph_optimization_level, - const logging::Logger& logger) const { - const auto& cpu_ep = *execution_providers_.Get(onnxruntime::kCpuExecutionProvider); - for (int i = static_cast(TransformerLevel::Level1); i <= static_cast(TransformerLevel::MaxLevel); i++) { - TransformerLevel level = static_cast(i); - if (graph_optimization_level >= level) { - // Generate and register transformers for level - auto transformers_to_register = [&]() { - return optimizer_utils::GenerateTransformersForEP(level, session_options_, cpu_ep, logger); - }(); - - for (auto& entry : transformers_to_register) { - ORT_RETURN_IF_ERROR(transformer_manager.Register(std::move(entry), level)); - } - } - } - return Status::OK(); -} -*/ - #endif // !defined(ORT_MINIMAL_BUILD) common::Status InferenceSession::WaitForNotification(Notification* p_executor_done, int64_t timeout_in_ms) { diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index ec31c9a22a6aa..7c380201aa2f3 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -713,21 +713,9 @@ class InferenceSession { RecordRuntimeOptimizationProducedNodeOpSchemaFn record_runtime_optimization_produced_op_schema_fn, 
const logging::Logger& logger) const; - virtual common::Status RegisterPredefinedOptimizersForEP( - GraphOptimizerRegistry& optimizer_registry, - const logging::Logger& logger) const; - - /* - virtual common::Status AddPredefinedTransformersForEP( - GraphTransformerManager& transformer_manager, - TransformerLevel graph_optimization_level, - const logging::Logger& logger) const; - */ - common::Status TransformGraph(onnxruntime::Graph& graph, bool saving_model_in_ort_format); onnxruntime::GraphTransformerManager graph_transformer_mgr_; - onnxruntime::GraphTransformerManager ep_graph_transformer_mgr_; InlinedHashSet> saved_runtime_optimization_produced_node_op_schemas_; #endif diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 6acc0b4f9b4c1..c954141ae06c7 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -204,8 +204,6 @@ common::Status LoadDynamicLibraryFromProvider(onnxruntime::PathString library_na } #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) -const GraphTransformerManager* graph_transformer_manager; - #if defined(_MSC_VER) && !defined(__clang__) #pragma warning(push) #pragma warning(disable : 26436) From 06ca0869a6a9f74ee42083b97d7308bb7a55599d Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Mon, 3 Feb 2025 14:48:25 -0800 Subject: [PATCH 12/24] remove redundant code --- .../core/framework/graph_partitioner.cc | 2 +- .../core/optimizer/graph_transformer_utils.cc | 27 ------------------- onnxruntime/core/session/inference_session.cc | 11 -------- .../test/framework/session_state_test.cc | 3 +++ 4 files changed, 4 insertions(+), 39 deletions(-) diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index 2fb639a1d118a..d7d52798fc3d7 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ 
-622,7 +622,7 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide InlinedHashSet claimed_by_ep; for (const auto& ep : execution_providers) { std::vector> capabilities; - ORT_RETURN_IF_ERROR(GetCapabilityForEPForAotInlining(graph_viewer, kernel_registry_mgr, * ep, logger, + ORT_RETURN_IF_ERROR(GetCapabilityForEPForAotInlining(graph_viewer, kernel_registry_mgr, *ep, logger, capabilities)); for (auto& capability : capabilities) { const auto& nodes = capability->sub_graph->nodes; diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 899f062f055bf..7be836f9e9f2e 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -200,33 +200,6 @@ InlinedVector> GenerateTransformersForEP( return transformers; } -/* -InlinedVector> GenerateTransformersForEP( - TransformerLevel level, - const SessionOptions& session_options, - const IExecutionProvider& cpu_execution_provider, - const logging::Logger& logger) { - InlinedVector> transformers; - switch (level) { - case TransformerLevel::Level1: { - break; - } - case TransformerLevel::Level2: { - const InlinedHashSet node_index_set = {}; - transformers.emplace_back(std::make_unique(cpu_execution_provider, false, - session_options.config_options, node_index_set)); - break; - } - case TransformerLevel::Level3: { - break; - } - default: - ORT_THROW("Unsupported optimization level: ", static_cast(level)); - } - return transformers; -} -*/ - InlinedVector> GenerateTransformers( TransformerLevel level, const SessionOptions& session_options, diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 121b64a951bb9..a852c195c1377 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1850,17 +1850,6 @@ common::Status InferenceSession::Initialize() { 
graph_optimizer_registry->AddPredefinedOptimizers(session_options_, *execution_providers_. Get(onnxruntime::kCpuExecutionProvider), *session_logger_); - /* - ORT_RETURN_IF_ERROR_SESSIONID_(RegisterPredefinedOptimizersForEP(*graph_optimizer_registry, - *session_logger_)); - */ - - /* - // add predefined transformers for EP - ORT_RETURN_IF_ERROR_SESSIONID_(AddPredefinedTransformersForEP(ep_graph_transformer_mgr_, - session_options_.graph_optimization_level, - *session_logger_)); - */ #ifdef USE_DML const IExecutionProvider* dmlExecutionProvider = execution_providers_.Get(kDmlExecutionProvider); diff --git a/onnxruntime/test/framework/session_state_test.cc b/onnxruntime/test/framework/session_state_test.cc index eb7553c07eeba..533f7792d52ba 100644 --- a/onnxruntime/test/framework/session_state_test.cc +++ b/onnxruntime/test/framework/session_state_test.cc @@ -261,6 +261,7 @@ TEST_P(SessionStateTestP, TestInitializerProcessing) { SessionState session_state(graph, execution_providers, tp.get(), nullptr, dtm, edlm, DefaultLoggingManager().DefaultLogger(), profiler, sess_options); + GraphPartitioner partitioner(krm, execution_providers); ASSERT_STATUS_OK( partitioner.Partition( @@ -346,6 +347,7 @@ TEST(SessionStateTest, TestInitializerMemoryAllocatedUsingNonArenaMemory) { SessionState session_state(graph, execution_providers, nullptr, nullptr, dtm, edlm, DefaultLoggingManager().DefaultLogger(), profiler, sess_options); + // Partition the graph GraphPartitioner partitioner(krm, execution_providers); ASSERT_STATUS_OK(partitioner.Partition( @@ -404,6 +406,7 @@ TEST(SessionStateTest, TestInitializerMemoryAllocatedUsingNonArenaMemory) { SessionState session_state(graph, execution_providers, nullptr, nullptr, dtm, edlm, DefaultLoggingManager().DefaultLogger(), profiler, sess_options); + // Partition the graph GraphPartitioner partitioner(krm, execution_providers); ASSERT_STATUS_OK(partitioner.Partition( From a965ffb4945344005f6d5e000912edf5944de371 Mon Sep 17 00:00:00 2001 
From: Chi Lo Date: Mon, 3 Feb 2025 14:49:38 -0800 Subject: [PATCH 13/24] remove redundant code --- onnxruntime/test/framework/session_state_test.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/test/framework/session_state_test.cc b/onnxruntime/test/framework/session_state_test.cc index 533f7792d52ba..e7f8b1aaa49d8 100644 --- a/onnxruntime/test/framework/session_state_test.cc +++ b/onnxruntime/test/framework/session_state_test.cc @@ -347,7 +347,7 @@ TEST(SessionStateTest, TestInitializerMemoryAllocatedUsingNonArenaMemory) { SessionState session_state(graph, execution_providers, nullptr, nullptr, dtm, edlm, DefaultLoggingManager().DefaultLogger(), profiler, sess_options); - + // Partition the graph GraphPartitioner partitioner(krm, execution_providers); ASSERT_STATUS_OK(partitioner.Partition( @@ -406,7 +406,7 @@ TEST(SessionStateTest, TestInitializerMemoryAllocatedUsingNonArenaMemory) { SessionState session_state(graph, execution_providers, nullptr, nullptr, dtm, edlm, DefaultLoggingManager().DefaultLogger(), profiler, sess_options); - + // Partition the graph GraphPartitioner partitioner(krm, execution_providers); ASSERT_STATUS_OK(partitioner.Partition( From 4c2697c687ce0d6d46be481daab7e04440ee1032 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Mon, 3 Feb 2025 16:21:43 -0800 Subject: [PATCH 14/24] add back function --- .../onnxruntime/core/optimizer/graph_transformer_utils.h | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/include/onnxruntime/core/optimizer/graph_transformer_utils.h b/include/onnxruntime/core/optimizer/graph_transformer_utils.h index 31b0f22340510..8f7acc8a5e803 100644 --- a/include/onnxruntime/core/optimizer/graph_transformer_utils.h +++ b/include/onnxruntime/core/optimizer/graph_transformer_utils.h @@ -58,6 +58,12 @@ InlinedVector> GenerateTransformers( concurrency::ThreadPool* intra_op_thread_pool = nullptr, std::unordered_map>* p_buffered_tensors = nullptr); +/** Generates all predefined transformers for 
EPs */ +InlinedVector> GenerateTransformersForEP( + const SessionOptions& session_options, + const IExecutionProvider& cpu_execution_provider, /*required by constant folding*/ + const logging::Logger& logger); + #endif // !defined(ORT_MINIMAL_BUILD) #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) From 2b81789507da1d3df63ecea8a12d2e460d1b1e77 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Tue, 4 Feb 2025 14:39:57 -0800 Subject: [PATCH 15/24] changed code per reviewer --- onnxruntime/core/optimizer/constant_folding.h | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/optimizer/constant_folding.h b/onnxruntime/core/optimizer/constant_folding.h index 7ce227f86fb04..423dbeae00ffb 100644 --- a/onnxruntime/core/optimizer/constant_folding.h +++ b/onnxruntime/core/optimizer/constant_folding.h @@ -28,15 +28,19 @@ class ConstantFolding : public GraphTransformer { const InlinedHashSet& compatible_execution_providers = {}, const InlinedHashSet& excluded_initializers = {}) noexcept; - /* Same as above but with a name provided by derived class. - */ +protected: + /** + * Same as the constructor above but with a name provided by derived class. + */ ConstantFolding(const std::string& name, const IExecutionProvider& execution_provider, bool skip_dequantize_linear, const ConfigOptions& config_options, const InlinedHashSet& compatible_execution_providers = {}, const InlinedHashSet& excluded_initializers = {}) noexcept; - + /** + * Derived class can implement this virtual function to limit the nodes that can be constant folded. 
+ */ virtual bool AllowConstantFolding(const Node& node) const { return true; } private: From 0c10cd43d55f5a8d87ebf8c42602e1bbacf27a57 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Thu, 6 Feb 2025 15:16:26 -0800 Subject: [PATCH 16/24] don't create optimizer instances until EP requests it by calling GetOptimizerByName --- .../optimizer/graph_optimizer_registry.cc | 70 ++++++++++- .../core/optimizer/graph_optimizer_registry.h | 65 ++++++++-- .../constant_folding_dq_node.cc | 11 +- .../constant_folding_dq_node.h | 6 +- .../selection_and_optimization_func.cc | 108 ++++++++++++++++ .../selection_and_optimization_func.h | 20 +++ .../shared_library/provider_interfaces.h | 1 + .../tensorrt/tensorrt_execution_provider.cc | 3 +- onnxruntime/core/session/inference_session.cc | 11 +- .../core/session/provider_bridge_ort.cc | 115 +----------------- 10 files changed, 275 insertions(+), 135 deletions(-) create mode 100644 onnxruntime/core/optimizer/selection_and_optimization_func.cc create mode 100644 onnxruntime/core/optimizer/selection_and_optimization_func.h diff --git a/onnxruntime/core/optimizer/graph_optimizer_registry.cc b/onnxruntime/core/optimizer/graph_optimizer_registry.cc index 2811acdeeaeed..4784b11f3b23d 100644 --- a/onnxruntime/core/optimizer/graph_optimizer_registry.cc +++ b/onnxruntime/core/optimizer/graph_optimizer_registry.cc @@ -3,23 +3,77 @@ #include "core/optimizer/graph_optimizer_registry.h" #include "core/optimizer/graph_transformer_utils.h" +#include "core/optimizer/selection_and_optimization_func.h" +#include "core/optimizer/qdq_transformer/constant_folding_dq_node.h" using namespace onnxruntime; using namespace ::onnxruntime::common; namespace onnxruntime { +GraphOptimizerRegistry::GraphOptimizerRegistry() { + logger_ = &logging::LoggingManager::DefaultLogger(); +} + common::Status GraphOptimizerRegistry::Register(std::unique_ptr transformer) { const auto& name = transformer->Name(); - if (name_to_transformer_map_.find(name) != 
name_to_transformer_map_.end()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "This transformer is already registered " + name); + if (name_to_transformer_map_.find(name) != name_to_transformer_map_.end() && + name_to_transformer_map_.at(name)) { + LOGS(*logger_, WARNING) << "This optimizer is already created and registered " << name; + return Status::OK(); } name_to_transformer_map_[name] = transformer.get(); transformer_list_.push_back(std::move(transformer)); + + if (name == kCONSTANT_FOLDING_DQ) { + transformer_name_to_selection_func_[name] = ConstantFoldingDQ_selection; + } + return Status::OK(); } +common::Status GraphOptimizerRegistry::AddPredefinedOptimizerNames(std::vector& optimizer_names) { + for (auto name : optimizer_names) { + if (name_to_transformer_map_.find(name) != name_to_transformer_map_.end()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "This transformer name is already added " + name); + } + name_to_transformer_map_[name] = nullptr; // The transformer will be instantizted only when EP requests it + } + return Status::OK(); +} + +std::unique_ptr GraphOptimizerRegistry::CreateOptimizer(std::string& name, std::unordered_map& key_value_configs) { + std::unique_ptr transformer; + if (name == kCONSTANT_FOLDING_DQ) { + const InlinedHashSet node_index_set = {}; + return std::make_unique(*cpu_ep_, false /*skip_dequantize_linear*/, + session_options_->config_options, node_index_set); + } + LOGS(*logger_, WARNING) << "Can't create optimizer " << name; + return transformer; +} + +std::optional>(const GraphViewer&)>> GraphOptimizerRegistry::GetSelectionFunc(std::string& name, + std::unordered_map& key_value_configs) const { + if (name_to_transformer_map_.find(name) == name_to_transformer_map_.end()) { + LOGS(*logger_, WARNING) << "Can't find optimizer " << name; + return std::nullopt; + } + + // Create and register if the transformer instance is not created. 
+ if (!name_to_transformer_map_.at(name)) { + auto new_transformer = Get()->CreateOptimizer(name, key_value_configs); + Get()->Register(std::move(new_transformer)); + } + + auto lookup = transformer_name_to_selection_func_.find(name); + if (lookup != transformer_name_to_selection_func_.end()) { + return transformer_name_to_selection_func_.at(name); + } + return std::nullopt; +} + GraphTransformer* GraphOptimizerRegistry::GetTransformerByName(std::string& name) const { if (name_to_transformer_map_.find(name) != name_to_transformer_map_.end()) { return name_to_transformer_map_.at(name); @@ -27,7 +81,7 @@ GraphTransformer* GraphOptimizerRegistry::GetTransformerByName(std::string& name return nullptr; } -// Registers all the predefined transformers for EP +// Create and register all the predefined transformers for EP common::Status GraphOptimizerRegistry::AddPredefinedOptimizers( const onnxruntime::SessionOptions& sess_options, const onnxruntime::IExecutionProvider& cpu_ep, @@ -56,6 +110,16 @@ common::Status GraphOptimizerRegistry::ApplyTransformer(Graph& graph, std::strin return Status::OK(); } +common::Status GraphOptimizerRegistry::AddCpuEpReference(onnxruntime::IExecutionProvider* cpu_ep) { + cpu_ep_ = cpu_ep; + return Status::OK(); +} + +common::Status GraphOptimizerRegistry::AddSessionOptionsReference(onnxruntime::SessionOptions* session_options) { + session_options_ = session_options; + return Status::OK(); +} + // Initialize static members std::shared_ptr onnxruntime::GraphOptimizerRegistry::graph_optimizer_registry = nullptr; std::mutex GraphOptimizerRegistry::registry_mutex; diff --git a/onnxruntime/core/optimizer/graph_optimizer_registry.h b/onnxruntime/core/optimizer/graph_optimizer_registry.h index fc425b50de44d..99ddc6542d665 100644 --- a/onnxruntime/core/optimizer/graph_optimizer_registry.h +++ b/onnxruntime/core/optimizer/graph_optimizer_registry.h @@ -5,16 +5,22 @@ #include "core/common/inlined_containers.h" #include "core/common/logging/logging.h" 
-//#include "core/common/common.h" #include "core/optimizer/graph_transformer.h" #include "core/framework/execution_providers.h" +#include "core/framework/compute_capability.h" namespace onnxruntime { +/** + * A registration/lookup class for re-usable optimizers for EPs. + */ class GraphOptimizerRegistry { public: - explicit GraphOptimizerRegistry() {} + explicit GraphOptimizerRegistry(); GraphOptimizerRegistry(const GraphOptimizerRegistry&) = delete; + /** + * Get GraphOptimizerRegistry instance as a singleton. + */ static std::shared_ptr Get() { if (!graph_optimizer_registry) { // First Check (without locking) std::lock_guard lock(registry_mutex); @@ -25,21 +31,66 @@ class GraphOptimizerRegistry { return graph_optimizer_registry; } + /** + * Register all the predefined optimizer names, only name not the optimizer instance. + * + * The optimizer will later be instantizted only when EP requests it by calling GetOptimizerByName in provider bridge. + */ + common::Status GraphOptimizerRegistry::AddPredefinedOptimizerNames(std::vector& optimizer_names); + + /** + * Create and register all predefined optimizers. + */ common::Status AddPredefinedOptimizers(const onnxruntime::SessionOptions& sess_options, - const onnxruntime::IExecutionProvider& cpu_ep, - const logging::Logger& logger); + const onnxruntime::IExecutionProvider& cpu_ep, + const logging::Logger& logger); + + /** + * Create optimizer instance. + */ + std::unique_ptr CreateOptimizer(std::string& name, std::unordered_map& key_value_configs); + + /** + * Get optimizer by name. + */ + GraphTransformer* GraphOptimizerRegistry::GetTransformerByName(std::string& name) const; + /** + * Run the optimizer. + */ common::Status ApplyTransformer(Graph& graph, std::string& name, const logging::Logger& logger) const; + /** + * Register optimizer and its optimization selection function. + */ common::Status Register(std::unique_ptr transformer); - - // Get transformer by name. Return nullptr if not found. 
- GraphTransformer* GetTransformerByName(std::string& name) const; + + /** + * Get optimizer selection function requested by EP. If the optimizer name can't be found, return nullopt. + * + * Please note that this function also creates and registers the optimizer if its instance is not existed. + */ + std::optional>(const GraphViewer&)>> GraphOptimizerRegistry::GetSelectionFunc(std::string& name, + std::unordered_map& key_value_configs) const; + + /** + * Add CPU EP reference from InferenceSession as it's needed for some optimizers, ex: ConstantFoldingDQ. + */ + common::Status AddCpuEpReference(onnxruntime::IExecutionProvider* cpu_ep); + + /** + * Add Session Options reference from InferenceSession as it's needed for some optimizers, ex: ConstantFoldingDQ. + */ + common::Status AddSessionOptionsReference(onnxruntime::SessionOptions* session_options); private: InlinedVector> transformer_list_; InlinedHashMap name_to_transformer_map_; + InlinedHashMap>(const GraphViewer&)>> transformer_name_to_selection_func_; + const logging::Logger* logger_; + onnxruntime::IExecutionProvider* cpu_ep_; + onnxruntime::SessionOptions* session_options_; static std::shared_ptr graph_optimizer_registry; static std::mutex registry_mutex; diff --git a/onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.cc b/onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.cc index 9bb09b542abbf..afbf68f8bb874 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.cc @@ -1,18 +1,9 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-#include - #include "core/optimizer/qdq_transformer/constant_folding_dq_node.h" -#include "core/optimizer/initializer.h" -#include "core/optimizer/utils.h" +#include "core/optimizer/graph_optimizer_registry.h" #include "core/graph/graph_utils.h" -#include "core/optimizer/optimizer_execution_frame.h" -#include "core/optimizer/utils.h" -#include "core/framework/op_kernel.h" -#include "core/framework/tensorprotoutils.h" - -using namespace onnxruntime::common; namespace onnxruntime { diff --git a/onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.h b/onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.h index 6b2f8aabdaf3e..a708f33040177 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.h +++ b/onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.h @@ -12,15 +12,13 @@ namespace onnxruntime { /** -@class ConstantFolding +@class ConstantFoldingDQ -Transformer that traverses the graph top-down and performs constant folding, i.e., -it statically computes parts of the graph that rely only on constant initializers. +It's the derived class from ConstantFolding. */ class ConstantFoldingDQ : public ConstantFolding { public: /*! Constant folding will not be applied to nodes that have one of initializers from excluded_initializers as input. - For pre-training, the trainable weights are those initializers to be excluded. \param execution_provider Execution provider instance to execute constant folding. */ ConstantFoldingDQ(const IExecutionProvider& execution_provider, diff --git a/onnxruntime/core/optimizer/selection_and_optimization_func.cc b/onnxruntime/core/optimizer/selection_and_optimization_func.cc new file mode 100644 index 0000000000000..a3fe45d89f928 --- /dev/null +++ b/onnxruntime/core/optimizer/selection_and_optimization_func.cc @@ -0,0 +1,108 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "selection_and_optimization_func.h" +#include "core/optimizer/graph_optimizer_registry.h" +#include "core/graph/graph_utils.h" +#include "core/framework/compute_capability.h" +#include "core/optimizer/qdq_transformer/constant_folding_dq_node.h" + +namespace onnxruntime { + +std::vector> ConstantFoldingDQ_selection(const GraphViewer& graph_viewer) { + std::vector> result; + std::unique_ptr sub_graph = std::make_unique(); + const std::vector& node_index = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED /*priority-based topological sort*/); + InitializedTensorSet constant_inputs; + const InlinedHashSet excluded_initializers; + + // Select DequantizeLinear node where all inputs are constant + for (const auto& index : node_index) { + const auto& node = graph_viewer.GetNode(index); + if (node->OpType() != "DequantizeLinear") { + continue; + } + if (!graph_utils::AllNodeInputsAreConstant(graph_viewer.GetGraph(), *node, constant_inputs, excluded_initializers)) { + continue; + } + sub_graph->nodes.push_back(index); + } + + result.push_back(std::make_unique(std::move(sub_graph))); + result.back()->optimization_func = ConstantFoldingDQ_optimization; + return result; +} + +Status ConstantFoldingDQ_optimization(Graph& graph, const ComputeCapability& optimization_cc, ComputeCapability& cc_to_update) { + std::string optimizer_name = kCONSTANT_FOLDING_DQ; + auto logger = const_cast(&logging::LoggingManager::DefaultLogger()); + std::unordered_set original_initializers_to_remove; + std::unordered_set new_initializers_to_add; + InlinedHashSet dq_node_index_set; + + // iterate node_to_optimize to: + // 1. get original initializers to remove + // 2. add new initializers + // 3. 
create dq node index set + for (const auto& index : optimization_cc.sub_graph->nodes) { + auto node = graph.GetNode(index); + if (node->OpType() != "DequantizeLinear") { + continue; + } + auto input_0 = node->InputDefs()[0]; + auto output_0 = node->OutputDefs()[0]; + original_initializers_to_remove.insert(input_0->Name()); + new_initializers_to_add.insert(output_0->Name()); + dq_node_index_set.insert(index); + } + + auto optimizer_registry = onnxruntime::GraphOptimizerRegistry::Get(); + + ConstantFoldingDQ* transformer = static_cast(optimizer_registry->GetTransformerByName(optimizer_name)); + transformer->UpdateNodeIndexSet(dq_node_index_set); + + // apply constant folding on DQ nodes + optimizer_registry->ApplyTransformer(graph, optimizer_name, *logger); + + // update the overall ComputeCapability + std::vector updated_nodes; + for (auto index : cc_to_update.sub_graph->nodes) { + if (dq_node_index_set.find(index) != dq_node_index_set.end()) { + continue; + } + updated_nodes.push_back(index); + } + cc_to_update.sub_graph->nodes = updated_nodes; + + auto original_meta_def = cc_to_update.sub_graph->GetMetaDef(); + std::unique_ptr updated_meta_def = std::make_unique(); + updated_meta_def->name = original_meta_def->name; + updated_meta_def->domain = original_meta_def->domain; + updated_meta_def->since_version = original_meta_def->since_version; + updated_meta_def->status = original_meta_def->status; + updated_meta_def->inputs = original_meta_def->inputs; + updated_meta_def->outputs = original_meta_def->outputs; + updated_meta_def->attributes = original_meta_def->attributes; + updated_meta_def->doc_string = original_meta_def->doc_string; +#if !defined(ORT_MINIMAL_BUILD) + updated_meta_def->type_and_shape_inference_function = original_meta_def->type_and_shape_inference_function; +#endif + for (auto constant_initializer : original_meta_def->constant_initializers) { + if (original_initializers_to_remove.find(constant_initializer) != original_initializers_to_remove.end()) { 
+ continue; + } + updated_meta_def->constant_initializers.push_back(constant_initializer); + } + + for (auto constant_initializer : new_initializers_to_add) { + updated_meta_def->constant_initializers.push_back(constant_initializer); + } + + cc_to_update.sub_graph->SetMetaDef(std::move(updated_meta_def)); + + return Status::OK(); +} + + + +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/selection_and_optimization_func.h b/onnxruntime/core/optimizer/selection_and_optimization_func.h new file mode 100644 index 0000000000000..08ce605c3df2e --- /dev/null +++ b/onnxruntime/core/optimizer/selection_and_optimization_func.h @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/framework/compute_capability.h" +#include "core/graph/graph_viewer.h" + +namespace onnxruntime { +static const std::string kCONSTANT_FOLDING_DQ = "ConstantFoldingDQ"; + +// ConstantFoldingDQ selection function +std::vector> ConstantFoldingDQ_selection(const GraphViewer& graph_viewer); + +// ConstantFoldingDQ optimization function +Status ConstantFoldingDQ_optimization(Graph& graph, const ComputeCapability& optimization_cc, ComputeCapability& cc_to_update); + + + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index 16fc15ea76725..ec6cd85b4279e 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -143,6 +143,7 @@ struct ProviderHost { virtual const OrtApiBase* OrtGetApiBase() = 0; virtual Status GetOptimizerByName(const std::string& name, + std::unordered_map& key_value_configs, std::function>(const GraphViewer&)>& selection_func) = 0; virtual void* HeapAllocate(size_t size) = 0; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc 
b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 8a899884e5892..41005522572ac 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -2670,7 +2670,8 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, */ std::function>(const GraphViewer&)> selection_func; - auto status = g_host->GetOptimizerByName("ConstantFoldingDQ", selection_func); + std::unordered_map key_value_configs = {}; + auto status = g_host->GetOptimizerByName("ConstantFoldingDQ", key_value_configs, selection_func); std::vector> selection_cc; if (selection_func) { selection_cc = selection_func(graph); diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index a852c195c1377..e17562eb84872 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1845,11 +1845,20 @@ common::Status InferenceSession::Initialize() { record_runtime_optimization_produced_op_schema, *session_logger_)); - // register predefined transformers for EP + // Register predefined optimizer names for EPs. + // It won't create optimizer instance until EP later requests it by calling GetOptimizerByName in provider bridge + std::vector predefined_optimizer_names_for_ep; + predefined_optimizer_names_for_ep.push_back("ConstantFoldingDQ"); auto graph_optimizer_registry = onnxruntime::GraphOptimizerRegistry::Get(); + graph_optimizer_registry->AddPredefinedOptimizerNames(predefined_optimizer_names_for_ep); + graph_optimizer_registry->AddCpuEpReference(execution_providers_.Get(onnxruntime::kCpuExecutionProvider)); + + // Don't create optimizer instances upfront. + /* graph_optimizer_registry->AddPredefinedOptimizers(session_options_, *execution_providers_. 
Get(onnxruntime::kCpuExecutionProvider), *session_logger_); + */ #ifdef USE_DML const IExecutionProvider* dmlExecutionProvider = execution_providers_.Get(kDmlExecutionProvider); diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index c954141ae06c7..10b5a4138b336 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -4,6 +4,7 @@ // This is the Onnxruntime side of the bridge to allow providers to be built as a DLL // It implements onnxruntime::ProviderHost +#include #include "core/common/inlined_containers.h" #include "core/common/path_string.h" #include "core/framework/allocator_utils.h" @@ -37,8 +38,6 @@ #include "core/framework/model_metadef_id_generator.h" #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" #include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" -#include "core/optimizer/qdq_transformer/constant_folding_dq_node.h" -#include "core/session/onnxruntime_session_options_config_keys.h" #include "core/session/onnxruntime_c_api.h" #include "core/common/string_helper.h" @@ -215,119 +214,17 @@ struct ProviderHostImpl : ProviderHost { const OrtApiBase* OrtGetApiBase() override { return ::OrtGetApiBase(); } Status GetOptimizerByName(const std::string& name, + std::unordered_map& key_value_configs, std::function>(const GraphViewer&)>& selection_func) override { - //static const GraphTransformerManager& graph_transformer_mgr = transformer_mgr; std::string optimizer_name(name); - // Pre-defined graph transformers/optimizers - static const std::string kEP_GRAPH_TRANSFORMER_CONSTANT_FOLDING_DQ = "ConstantFoldingDQ"; - - // ConstantFoldingDQ optimization function - auto constant_folding_dq_optimization = [&](Graph& graph, const ComputeCapability& optimization_cc, ComputeCapability& cc_to_update) -> Status { - std::string optimizer_name = kEP_GRAPH_TRANSFORMER_CONSTANT_FOLDING_DQ; - auto logger = 
const_cast(&logging::LoggingManager::DefaultLogger()); - std::unordered_set original_initializers_to_remove; - std::unordered_set new_initializers_to_add; - InlinedHashSet dq_node_index_set; - - // iterate node_to_optimize to: - // 1. get original initializers to remove - // 2. add new initializers - // 3. create dq node index set - for (const auto& index : optimization_cc.sub_graph->nodes) { - auto node = graph.GetNode(index); - if (node->OpType() != "DequantizeLinear") { - continue; - } - auto input_0 = node->InputDefs()[0]; - auto output_0 = node->OutputDefs()[0]; - original_initializers_to_remove.insert(input_0->Name()); - new_initializers_to_add.insert(output_0->Name()); - dq_node_index_set.insert(index); - } - - auto optimizer_registry = onnxruntime::GraphOptimizerRegistry::Get(); - - ConstantFoldingDQ* transformer = static_cast(optimizer_registry->GetTransformerByName(optimizer_name)); - transformer->UpdateNodeIndexSet(dq_node_index_set); - - // apply constant folding on DQ nodes - optimizer_registry->ApplyTransformer(graph, optimizer_name, *logger); - - // update the overall ComputeCapability - std::vector updated_nodes; - for (auto index : cc_to_update.sub_graph->nodes) { - if (dq_node_index_set.find(index) != dq_node_index_set.end()) { - continue; - } - updated_nodes.push_back(index); - } - cc_to_update.sub_graph->nodes = updated_nodes; - - auto original_meta_def = cc_to_update.sub_graph->GetMetaDef(); - std::unique_ptr updated_meta_def = std::make_unique(); - updated_meta_def->name = original_meta_def->name; - updated_meta_def->domain = original_meta_def->domain; - updated_meta_def->since_version = original_meta_def->since_version; - updated_meta_def->status = original_meta_def->status; - updated_meta_def->inputs = original_meta_def->inputs; - updated_meta_def->outputs = original_meta_def->outputs; - updated_meta_def->attributes = original_meta_def->attributes; - updated_meta_def->doc_string = original_meta_def->doc_string; -#if 
!defined(ORT_MINIMAL_BUILD) - updated_meta_def->type_and_shape_inference_function = original_meta_def->type_and_shape_inference_function; -#endif - for (auto constant_initializer : original_meta_def->constant_initializers) { - if (original_initializers_to_remove.find(constant_initializer) != original_initializers_to_remove.end()) { - continue; - } - updated_meta_def->constant_initializers.push_back(constant_initializer); - } + auto optimizer_registry = onnxruntime::GraphOptimizerRegistry::Get(); + auto func = optimizer_registry->GetSelectionFunc(optimizer_name, key_value_configs); - for (auto constant_initializer : new_initializers_to_add) { - updated_meta_def->constant_initializers.push_back(constant_initializer); - } - - cc_to_update.sub_graph->SetMetaDef(std::move(updated_meta_def)); - - return Status::OK(); - }; - - // ConstantFoldingDQ selection function - auto constant_folding_dq_selection = [&](const GraphViewer& graph_viewer) -> std::vector> { - std::vector> result; - std::unique_ptr sub_graph = std::make_unique(); - const std::vector& node_index = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED /*priority-based topological sort*/); - InitializedTensorSet constant_inputs; - const InlinedHashSet excluded_initializers; - - // Select DequantizeLinear node where all inputs are constant - for (const auto& index : node_index) { - const auto& node = graph_viewer.GetNode(index); - if (node->OpType() != "DequantizeLinear") { - continue; - } - if (!graph_utils::AllNodeInputsAreConstant(graph_viewer.GetGraph(), *node, constant_inputs, excluded_initializers)) { - continue; - } - sub_graph->nodes.push_back(index); - } - - result.push_back(std::make_unique(std::move(sub_graph))); - result.back()->optimization_func = constant_folding_dq_optimization; - return result; - }; - - // Optimizer lookup table - static std::unordered_map>(const GraphViewer&)>> optimizer_to_selection_function; - if 
(optimizer_to_selection_function.find(kEP_GRAPH_TRANSFORMER_CONSTANT_FOLDING_DQ) == optimizer_to_selection_function.end()) { - optimizer_to_selection_function[kEP_GRAPH_TRANSFORMER_CONSTANT_FOLDING_DQ] = constant_folding_dq_selection; + if (func.has_value()) { + selection_func = func.value(); } - auto look_up = optimizer_to_selection_function.find(optimizer_name); - if (look_up != optimizer_to_selection_function.end()) { - selection_func = optimizer_to_selection_function[optimizer_name]; - } return Status::OK(); }; From 3360dfd422edc3b212568572cf7b6f84b0717fd1 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Thu, 6 Feb 2025 15:22:38 -0800 Subject: [PATCH 17/24] minor modification --- onnxruntime/core/optimizer/graph_optimizer_registry.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/optimizer/graph_optimizer_registry.cc b/onnxruntime/core/optimizer/graph_optimizer_registry.cc index 4784b11f3b23d..6caa6ad2e28f0 100644 --- a/onnxruntime/core/optimizer/graph_optimizer_registry.cc +++ b/onnxruntime/core/optimizer/graph_optimizer_registry.cc @@ -36,7 +36,8 @@ common::Status GraphOptimizerRegistry::Register(std::unique_ptr& optimizer_names) { for (auto name : optimizer_names) { if (name_to_transformer_map_.find(name) != name_to_transformer_map_.end()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "This transformer name is already added " + name); + LOGS(*logger_, WARNING) << "This transformer name is already added " << name; + return Status::OK(); } name_to_transformer_map_[name] = nullptr; // The transformer will be instantizted only when EP requests it } From 5f7da9f7f430268993bd3bbdc160920fc1091be6 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Thu, 6 Feb 2025 16:47:44 -0800 Subject: [PATCH 18/24] fix compiler error --- .../core/framework/compute_capability.h | 3 + .../optimizer/graph_optimizer_registry.cc | 59 ++++++++----------- .../core/optimizer/graph_optimizer_registry.h | 23 +++++--- .../selection_and_optimization_func.cc | 19 
+++++- .../shared_library/provider_interfaces.h | 1 - .../tensorrt/tensorrt_execution_provider.cc | 3 +- .../core/session/provider_bridge_ort.cc | 3 +- 7 files changed, 61 insertions(+), 50 deletions(-) diff --git a/onnxruntime/core/framework/compute_capability.h b/onnxruntime/core/framework/compute_capability.h index 2d1d4e0e0153f..dfe8536fe983a 100644 --- a/onnxruntime/core/framework/compute_capability.h +++ b/onnxruntime/core/framework/compute_capability.h @@ -30,6 +30,9 @@ struct ComputeCapability { // Optimization: std::function std::function optimization_func; + // Optional key/value strings to configure an optimizer + std::unordered_map optimization_configs; + // optional ComputeCapability instances for sets of nodes within this ComputeCapability that should be optimized. // when an optimization is applied, ORT will update this ComputeCapability to reflect the changes made. // IndexedSubGraph.nodes: diff --git a/onnxruntime/core/optimizer/graph_optimizer_registry.cc b/onnxruntime/core/optimizer/graph_optimizer_registry.cc index 6caa6ad2e28f0..c7a702f1e4a52 100644 --- a/onnxruntime/core/optimizer/graph_optimizer_registry.cc +++ b/onnxruntime/core/optimizer/graph_optimizer_registry.cc @@ -15,24 +15,6 @@ GraphOptimizerRegistry::GraphOptimizerRegistry() { logger_ = &logging::LoggingManager::DefaultLogger(); } -common::Status GraphOptimizerRegistry::Register(std::unique_ptr transformer) { - const auto& name = transformer->Name(); - if (name_to_transformer_map_.find(name) != name_to_transformer_map_.end() && - name_to_transformer_map_.at(name)) { - LOGS(*logger_, WARNING) << "This optimizer is already created and registered " << name; - return Status::OK(); - } - - name_to_transformer_map_[name] = transformer.get(); - transformer_list_.push_back(std::move(transformer)); - - if (name == kCONSTANT_FOLDING_DQ) { - transformer_name_to_selection_func_[name] = ConstantFoldingDQ_selection; - } - - return Status::OK(); -} - common::Status 
GraphOptimizerRegistry::AddPredefinedOptimizerNames(std::vector& optimizer_names) { for (auto name : optimizer_names) { if (name_to_transformer_map_.find(name) != name_to_transformer_map_.end()) { @@ -40,38 +22,47 @@ common::Status GraphOptimizerRegistry::AddPredefinedOptimizerNames(std::vector GraphOptimizerRegistry::CreateOptimizer(std::string& name, std::unordered_map& key_value_configs) { - std::unique_ptr transformer; +common::Status GraphOptimizerRegistry::CreateOptimizer(std::string& name, std::unordered_map& key_value_configs) { if (name == kCONSTANT_FOLDING_DQ) { const InlinedHashSet node_index_set = {}; - return std::make_unique(*cpu_ep_, false /*skip_dequantize_linear*/, - session_options_->config_options, node_index_set); + auto transformer = std::make_unique(*cpu_ep_, false /*skip_dequantize_linear*/, + session_options_->config_options, node_index_set); + Get()->Register(std::move(transformer)); + return Status::OK(); } - LOGS(*logger_, WARNING) << "Can't create optimizer " << name; - return transformer; + + LOGS(*logger_, WARNING) << "Can't create optimizer for " << name << ". It's not in the predefined optimizer list."; + return Status::OK(); } -std::optional>(const GraphViewer&)>> GraphOptimizerRegistry::GetSelectionFunc(std::string& name, - std::unordered_map& key_value_configs) const { - if (name_to_transformer_map_.find(name) == name_to_transformer_map_.end()) { - LOGS(*logger_, WARNING) << "Can't find optimizer " << name; - return std::nullopt; +common::Status GraphOptimizerRegistry::Register(std::unique_ptr transformer) { + const auto& name = transformer->Name(); + if (name_to_transformer_map_.find(name) != name_to_transformer_map_.end() && + name_to_transformer_map_.at(name)) { + LOGS(*logger_, WARNING) << "This optimizer is already created and registered " << name; + return Status::OK(); } - // Create and register if the transformer instance is not created. 
- if (!name_to_transformer_map_.at(name)) { - auto new_transformer = Get()->CreateOptimizer(name, key_value_configs); - Get()->Register(std::move(new_transformer)); - } + name_to_transformer_map_[name] = transformer.get(); + transformer_list_.push_back(std::move(transformer)); + + return Status::OK(); +} +std::optional>(const GraphViewer&)>> GraphOptimizerRegistry::GetSelectionFunc(std::string& name) const { auto lookup = transformer_name_to_selection_func_.find(name); if (lookup != transformer_name_to_selection_func_.end()) { return transformer_name_to_selection_func_.at(name); } + LOGS(*logger_, WARNING) << "Can't find selection function of " << name; return std::nullopt; } diff --git a/onnxruntime/core/optimizer/graph_optimizer_registry.h b/onnxruntime/core/optimizer/graph_optimizer_registry.h index 99ddc6542d665..844f714104028 100644 --- a/onnxruntime/core/optimizer/graph_optimizer_registry.h +++ b/onnxruntime/core/optimizer/graph_optimizer_registry.h @@ -46,9 +46,9 @@ class GraphOptimizerRegistry { const logging::Logger& logger); /** - * Create optimizer instance. + * Create and register optimizer. */ - std::unique_ptr CreateOptimizer(std::string& name, std::unordered_map& key_value_configs); + common::Status GraphOptimizerRegistry::CreateOptimizer(std::string& name, std::unordered_map& key_value_configs); /** * Get optimizer by name. @@ -67,12 +67,9 @@ class GraphOptimizerRegistry { common::Status Register(std::unique_ptr transformer); /** - * Get optimizer selection function requested by EP. If the optimizer name can't be found, return nullopt. - * - * Please note that this function also creates and registers the optimizer if its instance is not existed. + * Get optimizer selection function. If the optimizer name can't be found, return nullopt. 
*/ - std::optional>(const GraphViewer&)>> GraphOptimizerRegistry::GetSelectionFunc(std::string& name, - std::unordered_map& key_value_configs) const; + std::optional>(const GraphViewer&)>> GraphOptimizerRegistry::GetSelectionFunc(std::string& name) const; /** * Add CPU EP reference from InferenceSession as it's needed for some optimizers, ex: ConstantFoldingDQ. @@ -80,10 +77,20 @@ class GraphOptimizerRegistry { common::Status AddCpuEpReference(onnxruntime::IExecutionProvider* cpu_ep); /** - * Add Session Options reference from InferenceSession as it's needed for some optimizers, ex: ConstantFoldingDQ. + * Get CPU EP reference. + */ + onnxruntime::IExecutionProvider* GetCpuEpReference() const { return cpu_ep_; } + + /** + * Add session options reference from InferenceSession as it's needed for some optimizers, ex: ConstantFoldingDQ. */ common::Status AddSessionOptionsReference(onnxruntime::SessionOptions* session_options); + /** + * Get Session Options reference. + */ + onnxruntime::SessionOptions* GetSessionOptionsReference() const { return session_options_; } + private: InlinedVector> transformer_list_; InlinedHashMap name_to_transformer_map_; diff --git a/onnxruntime/core/optimizer/selection_and_optimization_func.cc b/onnxruntime/core/optimizer/selection_and_optimization_func.cc index a3fe45d89f928..6592631ff07df 100644 --- a/onnxruntime/core/optimizer/selection_and_optimization_func.cc +++ b/onnxruntime/core/optimizer/selection_and_optimization_func.cc @@ -57,9 +57,22 @@ Status ConstantFoldingDQ_optimization(Graph& graph, const ComputeCapability& opt } auto optimizer_registry = onnxruntime::GraphOptimizerRegistry::Get(); - - ConstantFoldingDQ* transformer = static_cast(optimizer_registry->GetTransformerByName(optimizer_name)); - transformer->UpdateNodeIndexSet(dq_node_index_set); + + // ConstantFoldingDQ optimizer doesn't need the key/value strings. 
+ std::unordered_map key_value_configs = optimization_cc.optimization_configs; + + // Don't use CreateOptimizer as ConstantFoldingDQ needs dq_node_index_set for instantiation. + // optimizer_registry->CreateOptimizer(optimizer_name, key_value_configs); + + // Create ConstantFoldingDQ optimizer if it's not existed. + if (!optimizer_registry->GetTransformerByName(optimizer_name)) { + auto transformer = std::make_unique(*optimizer_registry->GetCpuEpReference(), + false /*skip_dequantize_linear*/, + optimizer_registry->GetSessionOptionsReference()->config_options, + dq_node_index_set); + optimizer_registry->Register(std::move(transformer)); + } + // apply constant folding on DQ nodes optimizer_registry->ApplyTransformer(graph, optimizer_name, *logger); diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index ec6cd85b4279e..16fc15ea76725 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -143,7 +143,6 @@ struct ProviderHost { virtual const OrtApiBase* OrtGetApiBase() = 0; virtual Status GetOptimizerByName(const std::string& name, - std::unordered_map& key_value_configs, std::function>(const GraphViewer&)>& selection_func) = 0; virtual void* HeapAllocate(size_t size) = 0; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 41005522572ac..8a899884e5892 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -2670,8 +2670,7 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, */ std::function>(const GraphViewer&)> selection_func; - std::unordered_map key_value_configs = {}; - auto status = g_host->GetOptimizerByName("ConstantFoldingDQ", key_value_configs, 
selection_func); + auto status = g_host->GetOptimizerByName("ConstantFoldingDQ", selection_func); std::vector> selection_cc; if (selection_func) { selection_cc = selection_func(graph); diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 10b5a4138b336..46346e9457f21 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -214,12 +214,11 @@ struct ProviderHostImpl : ProviderHost { const OrtApiBase* OrtGetApiBase() override { return ::OrtGetApiBase(); } Status GetOptimizerByName(const std::string& name, - std::unordered_map& key_value_configs, std::function>(const GraphViewer&)>& selection_func) override { std::string optimizer_name(name); auto optimizer_registry = onnxruntime::GraphOptimizerRegistry::Get(); - auto func = optimizer_registry->GetSelectionFunc(optimizer_name, key_value_configs); + auto func = optimizer_registry->GetSelectionFunc(optimizer_name); if (func.has_value()) { selection_func = func.value(); From e610bc88a06b7739a7c365b684cc43458452872d Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Thu, 6 Feb 2025 16:50:18 -0800 Subject: [PATCH 19/24] remove unnecessary member function --- .../optimizer/qdq_transformer/constant_folding_dq_node.cc | 5 ----- .../optimizer/qdq_transformer/constant_folding_dq_node.h | 1 - 2 files changed, 6 deletions(-) diff --git a/onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.cc b/onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.cc index afbf68f8bb874..a2f46d6ae693c 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.cc @@ -23,9 +23,4 @@ bool ConstantFoldingDQ::AllowConstantFolding(const Node& node) const { return false; } -Status ConstantFoldingDQ::UpdateNodeIndexSet(InlinedHashSet& node_index_set) { - node_index_set_ = node_index_set; - return Status::OK(); -} - 
} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.h b/onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.h index a708f33040177..fac31ac00d143 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.h +++ b/onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.h @@ -29,7 +29,6 @@ class ConstantFoldingDQ : public ConstantFolding { const InlinedHashSet& excluded_initializers = {}) noexcept; bool AllowConstantFolding(const Node& node) const; - Status UpdateNodeIndexSet(InlinedHashSet& node_index_set); private: InlinedHashSet node_index_set_; From e95f2c32005ffb3a095a875285f562baed203951 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Fri, 7 Feb 2025 00:57:47 +0000 Subject: [PATCH 20/24] lintrunner -a --- .../core/framework/compute_capability.h | 3 +-- onnxruntime/core/optimizer/constant_folding.h | 4 +-- .../optimizer/graph_optimizer_registry.cc | 4 +-- .../core/optimizer/graph_optimizer_registry.h | 10 +++---- .../core/optimizer/graph_transformer_mgr.cc | 2 +- .../constant_folding_dq_node.h | 2 +- .../selection_and_optimization_func.cc | 7 ++--- .../selection_and_optimization_func.h | 2 -- .../tensorrt/tensorrt_execution_provider.cc | 26 +++++++++---------- .../tensorrt/tensorrt_execution_provider.h | 2 +- .../tensorrt_execution_provider_helper.cc | 8 +++--- 11 files changed, 32 insertions(+), 38 deletions(-) diff --git a/onnxruntime/core/framework/compute_capability.h b/onnxruntime/core/framework/compute_capability.h index dfe8536fe983a..0cd41c60475f6 100644 --- a/onnxruntime/core/framework/compute_capability.h +++ b/onnxruntime/core/framework/compute_capability.h @@ -7,7 +7,6 @@ #include "core/graph/indexed_sub_graph.h" #include "core/graph/graph.h" - namespace onnxruntime { // A structure encodes a subgraph and the method to run it. 
struct ComputeCapability { @@ -24,7 +23,7 @@ struct ComputeCapability { ComputeCapability(std::unique_ptr t_sub_graph) : sub_graph(std::move(t_sub_graph)) {} - + // Optional function to optimize this ComputeCapability. // This will be called by ORT once the ComputeCapability is assigned to the EP // Optimization: std::function diff --git a/onnxruntime/core/optimizer/constant_folding.h b/onnxruntime/core/optimizer/constant_folding.h index 423dbeae00ffb..0c0ca0346e6b4 100644 --- a/onnxruntime/core/optimizer/constant_folding.h +++ b/onnxruntime/core/optimizer/constant_folding.h @@ -28,7 +28,7 @@ class ConstantFolding : public GraphTransformer { const InlinedHashSet& compatible_execution_providers = {}, const InlinedHashSet& excluded_initializers = {}) noexcept; -protected: + protected: /** * Same as the constructor above but with a name provided by derived class. */ @@ -41,7 +41,7 @@ class ConstantFolding : public GraphTransformer { /** * Derived class can implement this virtual function to limit the nodes that can be constant folded. 
*/ - virtual bool AllowConstantFolding(const Node& node) const { return true; } + virtual bool AllowConstantFolding(const Node& node) const { return true; } private: Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; diff --git a/onnxruntime/core/optimizer/graph_optimizer_registry.cc b/onnxruntime/core/optimizer/graph_optimizer_registry.cc index c7a702f1e4a52..dc264e37c741c 100644 --- a/onnxruntime/core/optimizer/graph_optimizer_registry.cc +++ b/onnxruntime/core/optimizer/graph_optimizer_registry.cc @@ -22,7 +22,7 @@ common::Status GraphOptimizerRegistry::AddPredefinedOptimizerNames(std::vector Get() { - if (!graph_optimizer_registry) { // First Check (without locking) + if (!graph_optimizer_registry) { // First Check (without locking) std::lock_guard lock(registry_mutex); - if (!graph_optimizer_registry) { // Second Check (with locking) + if (!graph_optimizer_registry) { // Second Check (with locking) graph_optimizer_registry = std::make_shared(); } } @@ -33,7 +33,7 @@ class GraphOptimizerRegistry { /** * Register all the predefined optimizer names, only name not the optimizer instance. - * + * * The optimizer will later be instantizted only when EP requests it by calling GetOptimizerByName in provider bridge. */ common::Status GraphOptimizerRegistry::AddPredefinedOptimizerNames(std::vector& optimizer_names); @@ -46,7 +46,7 @@ class GraphOptimizerRegistry { const logging::Logger& logger); /** - * Create and register optimizer. + * Create and register optimizer. */ common::Status GraphOptimizerRegistry::CreateOptimizer(std::string& name, std::unordered_map& key_value_configs); @@ -56,7 +56,7 @@ class GraphOptimizerRegistry { GraphTransformer* GraphOptimizerRegistry::GetTransformerByName(std::string& name) const; /** - * Run the optimizer. + * Run the optimizer. 
*/ common::Status ApplyTransformer(Graph& graph, std::string& name, const logging::Logger& logger) const; diff --git a/onnxruntime/core/optimizer/graph_transformer_mgr.cc b/onnxruntime/core/optimizer/graph_transformer_mgr.cc index f56bc3bfab15f..d67eb12e2a994 100644 --- a/onnxruntime/core/optimizer/graph_transformer_mgr.cc +++ b/onnxruntime/core/optimizer/graph_transformer_mgr.cc @@ -50,7 +50,7 @@ common::Status GraphTransformerManager::ApplyTransformer(Graph& graph, std::stri if (!transformer) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "This transformer is not registered " + name); } - + bool modified = false; for (unsigned step = 0; step < steps_; ++step) { if (step > 0 && transformer->ShouldOnlyApplyOnce()) { diff --git a/onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.h b/onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.h index fac31ac00d143..b5341f380bc2c 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.h +++ b/onnxruntime/core/optimizer/qdq_transformer/constant_folding_dq_node.h @@ -23,7 +23,7 @@ class ConstantFoldingDQ : public ConstantFolding { */ ConstantFoldingDQ(const IExecutionProvider& execution_provider, bool skip_dequantize_linear, - const ConfigOptions& config_options, + const ConfigOptions& config_options, const InlinedHashSet& node_index_set, const InlinedHashSet& compatible_execution_providers = {}, const InlinedHashSet& excluded_initializers = {}) noexcept; diff --git a/onnxruntime/core/optimizer/selection_and_optimization_func.cc b/onnxruntime/core/optimizer/selection_and_optimization_func.cc index 6592631ff07df..c703aad62114d 100644 --- a/onnxruntime/core/optimizer/selection_and_optimization_func.cc +++ b/onnxruntime/core/optimizer/selection_and_optimization_func.cc @@ -57,10 +57,10 @@ Status ConstantFoldingDQ_optimization(Graph& graph, const ComputeCapability& opt } auto optimizer_registry = onnxruntime::GraphOptimizerRegistry::Get(); - + // ConstantFoldingDQ optimizer 
doesn't need the key/value strings. std::unordered_map key_value_configs = optimization_cc.optimization_configs; - + // Don't use CreateOptimizer as ConstantFoldingDQ needs dq_node_index_set for instantiation. // optimizer_registry->CreateOptimizer(optimizer_name, key_value_configs); @@ -72,7 +72,6 @@ Status ConstantFoldingDQ_optimization(Graph& graph, const ComputeCapability& opt dq_node_index_set); optimizer_registry->Register(std::move(transformer)); } - // apply constant folding on DQ nodes optimizer_registry->ApplyTransformer(graph, optimizer_name, *logger); @@ -116,6 +115,4 @@ Status ConstantFoldingDQ_optimization(Graph& graph, const ComputeCapability& opt return Status::OK(); } - - } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/selection_and_optimization_func.h b/onnxruntime/core/optimizer/selection_and_optimization_func.h index 08ce605c3df2e..2b58b3e3da18a 100644 --- a/onnxruntime/core/optimizer/selection_and_optimization_func.h +++ b/onnxruntime/core/optimizer/selection_and_optimization_func.h @@ -15,6 +15,4 @@ std::vector> ConstantFoldingDQ_selection(cons // ConstantFoldingDQ optimization function Status ConstantFoldingDQ_optimization(Graph& graph, const ComputeCapability& optimization_cc, ComputeCapability& cc_to_update); - - } // namespace onnxruntime diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 8a899884e5892..bce845b52c2b2 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -2554,7 +2554,7 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, } bool early_termination = false; - //supported_nodes_vector = GetSupportedList(parser_nodes_vector, 0, max_partition_iterations_, graph, &early_termination); + // supported_nodes_vector = GetSupportedList(parser_nodes_vector, 0, max_partition_iterations_, graph, 
&early_termination); supported_nodes_vector = parser_nodes_vector; if (early_termination) { supported_nodes_vector.clear(); @@ -2656,15 +2656,15 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, } } - /** + /** * Enable EP related L2+ graph optimizations with steps: - * + * * 1. Call provider bridge API to lookup pre-defined optimizer by name and get selection function * - Run selection function to get selection ComputeCapability * - ComputeCapability.optimize_func would be set by the optimizer to the function that does the optimization - * - * - * + * + * + * * Current available optimizations: * - (ConstantFoldingDQ) constant folding on DQ nodes -> Dequantize INT32, UINT16, INT16 constant to FP32. */ @@ -2676,10 +2676,10 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, selection_cc = selection_func(graph); } - std::unordered_set trt_selection_node_set; // The qualified dq nodes selected by TRT EP - std::unordered_map consumer_to_dq; // consumer node -> dq node + std::unordered_set trt_selection_node_set; // The qualified dq nodes selected by TRT EP + std::unordered_map consumer_to_dq; // consumer node -> dq node // Note: The NodeIndex here is the node index in the graph, not the index in node vector in supported_nodes_vector. - + SelectQualifiedDQNode(graph, trt_selection_node_set, consumer_to_dq); // Include nodes that are filtered out by TRT parser. 
@@ -2692,7 +2692,7 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, } auto dq_node_index = consumer_to_dq[node_index[index]]; - + // Check if DQ node is included in one of the subgraphs auto in_the_subgraph_collection = [&](NodeIndex node_idx) -> bool { for (auto& node_vector : supported_nodes_vector) { @@ -2718,7 +2718,7 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, supported_node_vector.first.push_back(static_cast(idx)); auto node = graph.GetNode(dq_node_index); LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " << node->Name() << " is included which is filtered out by TRT parser."; - } + } } } }; @@ -2730,13 +2730,13 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, // TODO: Use consumer_to_dq table to include DQ node that is filtered out by TRT parser. std::unique_ptr sub_graph = GetSubGraph(group, graph, model_hash, subgraph_index); auto compute_capability = ComputeCapability::Create(std::move(sub_graph)); - + // add optimization ComputeCapability to node_to_optimize for (auto& cc : selection_cc) { std::unique_ptr optimization_cc = CreateOptimizationComputeCapability(cc.get(), trt_selection_node_set, compute_capability.get()); compute_capability->add_nodes_to_optimize(std::move(optimization_cc)); } - + result.push_back(std::move(compute_capability)); number_of_trt_nodes += static_cast(group.first.size()); subgraph_index++; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index 4946058e16d75..5889ff9960cd6 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -600,7 +600,7 @@ class TensorrtExecutionProvider : public IExecutionProvider { void SelectQualifiedDQNode(const GraphViewer& graph, std::unordered_set& selection_node_set, std::unordered_map& consumer_to_dq) const; - + /** * This function returns an 
optimization ComputeCapability that is limited to: * 1. the DQ nodes in this individual TRT ComputeCapability diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_helper.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_helper.cc index 726988bef1552..aeba2854b944c 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_helper.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_helper.cc @@ -261,8 +261,8 @@ void TensorrtExecutionProvider::SetAllGraphInputs(Graph& graph) const { /** * This is the helper function for ConstantFoldingDQ graph transformer. - * - * It selects the qualified/required DQ node to be optimized as well as provides a mapping table + * + * It selects the qualified/required DQ node to be optimized as well as provides a mapping table * to help TRT EP later include the DQ node which is filtered out by TRT parser. */ void TensorrtExecutionProvider::SelectQualifiedDQNode(const GraphViewer& graph, @@ -286,7 +286,7 @@ void TensorrtExecutionProvider::SelectQualifiedDQNode(const GraphViewer& graph, // 3. The fist input of DQ is constant initializer. // 4. The data type of initializer is INT32, UINT16 or INT16 // 4. X should be Gemm, Conv or LayerNormalization ? - if (node->OpType() == "DequantizeLinear" && + if (node->OpType() == "DequantizeLinear" && node->GetOutputEdgesCount() == 1 && (data_type == ONNX_NAMESPACE::TensorProto_DataType_INT32 || data_type == ONNX_NAMESPACE::TensorProto_DataType_INT16 || data_type == ONNX_NAMESPACE::TensorProto_DataType_UINT16) && constant_initializer) { @@ -303,7 +303,7 @@ void TensorrtExecutionProvider::SelectQualifiedDQNode(const GraphViewer& graph, * This function returns an optimization ComputeCapability that is limited to: * 1. the DQ nodes in this individual TRT ComputeCapability * 2. 
the DQ nodes that are qualified and selected by TRT EP - * + * * It also needs to make sure the DQ nodes is a subset of the complete list of DQ nodes to optimize in original selection ComputeCapability. * Finally, copy the optimization function from the original selection ComputeCapability. */ From bad19b9c9efa9e24c919d7e6f9af2630ee7885bb Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Thu, 6 Feb 2025 17:25:22 -0800 Subject: [PATCH 21/24] handle status --- .../providers/tensorrt/tensorrt_execution_provider.cc | 11 ++++++++--- onnxruntime/core/session/provider_bridge_ort.cc | 3 ++- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index bce845b52c2b2..48fbb6c6fedee 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -2670,10 +2670,15 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, */ std::function>(const GraphViewer&)> selection_func; - auto status = g_host->GetOptimizerByName("ConstantFoldingDQ", selection_func); std::vector> selection_cc; - if (selection_func) { - selection_cc = selection_func(graph); + std::string optimizer_name = "ConstantFoldingDQ"; + auto status = g_host->GetOptimizerByName(optimizer_name, selection_func); + if (status == Status::OK()) { + if (selection_func) { + selection_cc = selection_func(graph); + } + } else { + LOGS_DEFAULT(WARNING) << "[TensorRT EP] Can't get optimizer " << optimizer_name; } std::unordered_set trt_selection_node_set; // The qualified dq nodes selected by TRT EP diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 46346e9457f21..483ef09468165 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -222,8 +222,9 @@ struct 
ProviderHostImpl : ProviderHost { if (func.has_value()) { selection_func = func.value(); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to get optimizer " + optimizer_name); } - return Status::OK(); }; From d4968cb3f976ce61e30e0f4ded44df46972817f9 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Fri, 7 Feb 2025 14:28:55 -0800 Subject: [PATCH 22/24] remove unnecessary code --- onnxruntime/core/framework/graph_partitioner.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index d7d52798fc3d7..26d1e6cb79e16 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -12,7 +12,6 @@ #include "core/framework/kernel_lookup.h" #include "core/framework/kernel_registry_manager.h" #include "core/framework/kernel_registry.h" -#include "core/optimizer/graph_optimizer_registry.h" #include "core/graph/function.h" #include "core/graph/function_utils.h" #include "core/graph/graph_viewer.h" @@ -259,7 +258,6 @@ static Status GetCapabilityForEPForAotInlining(const GraphViewer& graph_viewer, kernel_registries_for_ep, kernel_registry_mgr.GetKernelTypeStrResolver(), logger}; - auto graph_optimizer_registry = onnxruntime::GraphOptimizerRegistry::Get(); // TODO: Provide EP with a capability to look inside the functions. 
capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup); From df5aca92d6665693c1b1ebe4b0a0c9a13904b4fb Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Sun, 9 Feb 2025 09:21:48 -0800 Subject: [PATCH 23/24] add GetMutableMetaDef --- .../core/graph/indexed_sub_graph.h | 6 +++++ .../selection_and_optimization_func.cc | 24 ++++++------------- 2 files changed, 13 insertions(+), 17 deletions(-) diff --git a/include/onnxruntime/core/graph/indexed_sub_graph.h b/include/onnxruntime/core/graph/indexed_sub_graph.h index c57db41254159..6d86f5df56ba0 100644 --- a/include/onnxruntime/core/graph/indexed_sub_graph.h +++ b/include/onnxruntime/core/graph/indexed_sub_graph.h @@ -70,6 +70,12 @@ struct IndexedSubGraph { return meta_def_.get(); } + /** Gets the mutable meta definition needed to represent this subgraph as a FunctionProto. + @returns MetaDef instance if it has been set. nullptr if not. */ + MetaDef* GetMutableMetaDef() const { + return meta_def_.get(); + } + private: // subgraph meta definition. 
std::unique_ptr meta_def_; diff --git a/onnxruntime/core/optimizer/selection_and_optimization_func.cc b/onnxruntime/core/optimizer/selection_and_optimization_func.cc index c703aad62114d..176070f46f671 100644 --- a/onnxruntime/core/optimizer/selection_and_optimization_func.cc +++ b/onnxruntime/core/optimizer/selection_and_optimization_func.cc @@ -86,31 +86,21 @@ Status ConstantFoldingDQ_optimization(Graph& graph, const ComputeCapability& opt } cc_to_update.sub_graph->nodes = updated_nodes; - auto original_meta_def = cc_to_update.sub_graph->GetMetaDef(); - std::unique_ptr updated_meta_def = std::make_unique(); - updated_meta_def->name = original_meta_def->name; - updated_meta_def->domain = original_meta_def->domain; - updated_meta_def->since_version = original_meta_def->since_version; - updated_meta_def->status = original_meta_def->status; - updated_meta_def->inputs = original_meta_def->inputs; - updated_meta_def->outputs = original_meta_def->outputs; - updated_meta_def->attributes = original_meta_def->attributes; - updated_meta_def->doc_string = original_meta_def->doc_string; -#if !defined(ORT_MINIMAL_BUILD) - updated_meta_def->type_and_shape_inference_function = original_meta_def->type_and_shape_inference_function; -#endif - for (auto constant_initializer : original_meta_def->constant_initializers) { + auto meta_def = cc_to_update.sub_graph->GetMutableMetaDef(); + std::vector updated_constant_initializers; + + for (auto constant_initializer : meta_def->constant_initializers) { if (original_initializers_to_remove.find(constant_initializer) != original_initializers_to_remove.end()) { continue; } - updated_meta_def->constant_initializers.push_back(constant_initializer); + updated_constant_initializers.push_back(constant_initializer); } for (auto constant_initializer : new_initializers_to_add) { - updated_meta_def->constant_initializers.push_back(constant_initializer); + updated_constant_initializers.push_back(constant_initializer); } - 
cc_to_update.sub_graph->SetMetaDef(std::move(updated_meta_def)); + meta_def->constant_initializers = updated_constant_initializers; return Status::OK(); } From 60d95993839bbe48ca6da8dab3ea5f3e401d1088 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Sun, 9 Feb 2025 15:37:17 -0800 Subject: [PATCH 24/24] update TRT EP --- .../tensorrt/tensorrt_execution_provider.cc | 59 ++++--------------- .../tensorrt/tensorrt_execution_provider.h | 8 +++ .../tensorrt_execution_provider_helper.cc | 59 ++++++++++++++++++- 3 files changed, 77 insertions(+), 49 deletions(-) diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 48fbb6c6fedee..55f9532076f72 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -2554,8 +2554,7 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, } bool early_termination = false; - // supported_nodes_vector = GetSupportedList(parser_nodes_vector, 0, max_partition_iterations_, graph, &early_termination); - supported_nodes_vector = parser_nodes_vector; + supported_nodes_vector = GetSupportedList(parser_nodes_vector, 0, max_partition_iterations_, graph, &early_termination); if (early_termination) { supported_nodes_vector.clear(); } @@ -2660,13 +2659,13 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, * Enable EP related L2+ graph optimizations with steps: * * 1. Call provider bridge API to lookup pre-defined optimizer by name and get selection function - * - Run selection function to get selection ComputeCapability - * - ComputeCapability.optimize_func would be set by the optimizer to the function that does the optimization + * 2. 
Run selection function to get selection ComputeCapability + - ComputeCapability.optimize_func would be set by the optimizer to the function that does the optimization * * * * Current available optimizations: - * - (ConstantFoldingDQ) constant folding on DQ nodes -> Dequantize INT32, UINT16, INT16 constant to FP32. + * - (ConstantFoldingDQ) constant folding on DQ nodes, i.e. dequantize INT32, UINT16, INT16 constant to FP32. */ std::function>(const GraphViewer&)> selection_func; @@ -2687,52 +2686,16 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, SelectQualifiedDQNode(graph, trt_selection_node_set, consumer_to_dq); - // Include nodes that are filtered out by TRT parser. - auto update_supported_node_vector = [&](SubGraph_t& supported_node_vector, SubGraphCollection_t& supported_nodes_vector) -> void { - if (!consumer_to_dq.empty()) { - const std::vector& node_index = graph.GetNodesInTopologicalOrder(1); - for (auto index : supported_node_vector.first) { - if (consumer_to_dq.find(node_index[index]) == consumer_to_dq.end()) { - continue; - } - - auto dq_node_index = consumer_to_dq[node_index[index]]; - - // Check if DQ node is included in one of the subgraphs - auto in_the_subgraph_collection = [&](NodeIndex node_idx) -> bool { - for (auto& node_vector : supported_nodes_vector) { - if (!node_vector.second) { - continue; - } - for (auto index : node_vector.first) { - if (node_index[index] == node_idx) { - return true; - } - } - } - return false; - }; - if (in_the_subgraph_collection(dq_node_index)) { - continue; - } - // Find the iterator pointing to the target element - auto it = std::find(node_index.begin(), node_index.end(), dq_node_index); - if (it != node_index.end()) { - // Calculate the index - int idx = std::distance(node_index.begin(), it); - supported_node_vector.first.push_back(static_cast(idx)); - auto node = graph.GetNode(dq_node_index); - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " << node->Name() << " is included which is filtered out 
by TRT parser."; - } - } - } - }; - // Create ComputeCapability int number_of_trt_nodes = 0, subgraph_index = 0; - for (const auto& group : supported_nodes_vector) { + for (auto& group : supported_nodes_vector) { if (!group.first.empty()) { - // TODO: Use consumer_to_dq table to include DQ node that is filtered out by TRT parser. + + if (!selection_cc.empty()) { + // Include DQ nodes that are filtered out by TRT parser + UpdateSupportedNodeVectorForDQ(graph, group, supported_nodes_vector, consumer_to_dq); + } + std::unique_ptr sub_graph = GetSubGraph(group, graph, model_hash, subgraph_index); auto compute_capability = ComputeCapability::Create(std::move(sub_graph)); diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index 5889ff9960cd6..45b15368ac608 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -612,5 +612,13 @@ class TensorrtExecutionProvider : public IExecutionProvider { std::unique_ptr CreateOptimizationComputeCapability(ComputeCapability* selection_cc, std::unordered_set& trt_selection_node_set, ComputeCapability* trt_cc) const; + /** + * This function helps add back the DQ nodes that are filtered out by TRT parser. + * The reason is the DQ nodes can be optimized and dequantized by applying ConstantFoldingDQ optimizer by ORT L2+ optimization. 
+ */ + void TensorrtExecutionProvider::UpdateSupportedNodeVectorForDQ(const GraphViewer& graph, + SubGraph_t& supported_node_vector, + SubGraphCollection_t& supported_nodes_vector, + std::unordered_map consumer_to_dq) const; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_helper.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_helper.cc index aeba2854b944c..702bc6108b496 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_helper.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_helper.cc @@ -293,7 +293,7 @@ void TensorrtExecutionProvider::SelectQualifiedDQNode(const GraphViewer& graph, const Node& consumer_node = *node->OutputNodesBegin(); selection_node_set.insert(index); consumer_to_dq[consumer_node.Index()] = index; - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " << consumer_node.Name() << " < -" << node->Name(); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " << consumer_node.Name() << " <- " << node->Name(); } } LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Total " << selection_node_set.size() << " DequantizeLinear node(s) are selected."; @@ -330,4 +330,61 @@ std::unique_ptr TensorrtExecutionProvider::CreateOptimization compute_capability->copy_optimization_func(selection_cc); return compute_capability; } + +/** + * This function helps add back the DQ nodes that are filtered out by TRT parser. + * The reason is the DQ nodes can be optimized and dequantized by applying ConstantFoldingDQ optimizer by ORT L2+ optimization. 
+ */ +void TensorrtExecutionProvider::UpdateSupportedNodeVectorForDQ(const GraphViewer& graph, + SubGraph_t& supported_node_vector, + SubGraphCollection_t& supported_nodes_vector, + std::unordered_map consumer_to_dq) const { + if (consumer_to_dq.empty()) { + return; + } + + if (!supported_node_vector.second) { + return; + } + + const std::vector& node_index = graph.GetNodesInTopologicalOrder(1); + auto supported_nodes = supported_node_vector.first; + for (auto index : supported_nodes) { + if (consumer_to_dq.find(node_index[index]) == consumer_to_dq.end()) { + continue; + } + + auto dq_node_index = consumer_to_dq[node_index[index]]; + + // Check if DQ node is included in one of the subgraphs + auto in_the_subgraph_collection = [&](NodeIndex node_idx) -> bool { + for (auto& node_vector : supported_nodes_vector) { + if (!node_vector.second) { + continue; + } + for (auto i : node_vector.first) { + if (node_index[i] == node_idx) { + return true; + } + } + } + return false; + }; + + // If the DQ node is already in the subgraph, do nothing. + if (in_the_subgraph_collection(dq_node_index)) { + continue; + } + + // Find the iterator pointing to the target element + auto it = std::find(node_index.begin(), node_index.end(), dq_node_index); + if (it != node_index.end()) { + // Calculate the index + int idx = std::distance(node_index.begin(), it); + supported_node_vector.first.push_back(static_cast(idx)); + auto node = graph.GetNode(dq_node_index); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " << node->Name() << " is included which is filtered out by TRT parser."; + } + } +} } // namespace onnxruntime