Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enabling L2+ Optimizations for EPs #23517

Draft
wants to merge 22 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions include/onnxruntime/core/framework/execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ struct ComputeCapability;
class KernelRegistry;
struct KernelCreateInfo;
class Node;
class GraphTransformerManager;
} // namespace onnxruntime
#else
#include <memory>
Expand Down Expand Up @@ -128,9 +129,16 @@ class IExecutionProvider {
For kernels registered in a kernel registry, `kernel_lookup` must be used
to find a matching kernel for this EP.
*/
/*
virtual std::vector<std::unique_ptr<ComputeCapability>>
GetCapability(const onnxruntime::GraphViewer& graph_viewer,
const IKernelLookup& kernel_lookup) const;
*/

virtual std::vector<std::unique_ptr<ComputeCapability>>
GetCapability(const onnxruntime::GraphViewer& graph_viewer,
const IKernelLookup& kernel_lookup,
const onnxruntime::GraphTransformerManager& graph_transformer_mgr) const;

/**
Get kernel registry per execution provider type.
Expand Down
6 changes: 6 additions & 0 deletions include/onnxruntime/core/optimizer/graph_transformer_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,12 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
concurrency::ThreadPool* intra_op_thread_pool = nullptr,
std::unordered_map<std::string, std::unique_ptr<Tensor>>* p_buffered_tensors = nullptr);

InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformersForEP(
TransformerLevel level,
const SessionOptions& session_options,
const IExecutionProvider& cpu_execution_provider, /*required by constant folding*/
const logging::Logger& logger);

#endif // !defined(ORT_MINIMAL_BUILD)

#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
Expand Down
17 changes: 17 additions & 0 deletions onnxruntime/core/framework/compute_capability.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@
// Licensed under the MIT License.

#pragma once
#include <functional>
#include "core/common/common.h"
#include "core/graph/indexed_sub_graph.h"
#include "core/graph/graph.h"


namespace onnxruntime {
chilo-ms marked this conversation as resolved.
Show resolved Hide resolved
// A structure encodes a subgraph and the method to run it.
Expand All @@ -21,5 +24,19 @@ struct ComputeCapability {

ComputeCapability(std::unique_ptr<IndexedSubGraph> t_sub_graph)
: sub_graph(std::move(t_sub_graph)) {}

// Optional function to optimize this ComputeCapability.
chilo-ms marked this conversation as resolved.
Show resolved Hide resolved
// This will be called by ORT once the ComputeCapability is assigned to the EP
// Optimization: std::function<Status(const Graph& graph, const ComputeCapability& this_optimization, ComputeCapability& cc_to_update)>
std::function<Status(Graph&, const ComputeCapability&, ComputeCapability&)> optimization_func;

// optional ComputeCapability instances for sets of nodes within this ComputeCapability that should be optimized.
// when an optimization is applied, ORT will update this ComputeCapability to reflect the changes made.
// IndexedSubGraph.nodes:
// - update based on RemovedNode/AddNode calls
// IndexedSubGraph.MetaDef (if present):
// - inputs and outputs will be unchanged
// - constant_initializers MAY change if we constant fold an initializer during optimization
std::vector<std::unique_ptr<ComputeCapability>> nodes_to_optimize;
};
} // namespace onnxruntime
3 changes: 2 additions & 1 deletion onnxruntime/core/framework/execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ namespace onnxruntime {

std::vector<std::unique_ptr<ComputeCapability>>
IExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
const IKernelLookup& kernel_lookup) const {
const IKernelLookup& kernel_lookup,
const onnxruntime::GraphTransformerManager& graph_transformer_mgr) const {
std::vector<std::unique_ptr<ComputeCapability>> result;
for (const auto& node : graph.Nodes()) {
if (const KernelCreateInfo* kernel_create_info = kernel_lookup.LookUpKernel(node);
Expand Down
56 changes: 46 additions & 10 deletions onnxruntime/core/framework/graph_partitioner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ struct PartitionParams {
std::reference_wrapper<int> fused_node_unique_id;
std::reference_wrapper<const layout_transformation::TransformLayoutFunction> transform_layout_function;
std::reference_wrapper<const layout_transformation::DebugGraphFn> debug_graph_fn;
std::reference_wrapper<const onnxruntime::GraphTransformerManager> graph_transformer_manager;
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
};
} // namespace
Expand Down Expand Up @@ -130,13 +131,16 @@ struct GetCapabilityForEPParams {
GraphPartitioner::Mode mode;
std::reference_wrapper<const layout_transformation::TransformLayoutFunction> transform_layout;
std::reference_wrapper<const layout_transformation::DebugGraphFn> debug_graph_fn;
std::reference_wrapper<const onnxruntime::GraphTransformerManager> graph_transformer_manager;
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
};

auto get_capabilities = [](const IExecutionProvider& ep,
const GraphViewer& graph_viewer,
const IExecutionProvider::IKernelLookup& kernel_lookup) {
auto capabilities = ep.GetCapability(graph_viewer, kernel_lookup);
const IExecutionProvider::IKernelLookup& kernel_lookup,
const onnxruntime::GraphTransformerManager& graph_transformer_manager) {
std::vector<std::unique_ptr<ComputeCapability>> capabilities;
capabilities = ep.GetCapability(graph_viewer, kernel_lookup, graph_transformer_manager);

// In theory an EP could return an empty capability. Remove those.
capabilities.erase(std::remove_if(capabilities.begin(), capabilities.end(),
Expand Down Expand Up @@ -170,10 +174,11 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params, const l

auto& graph = params.graph.get();
auto& capabilities = params.capabilities.get();
auto& graph_transformer_manager = params.graph_transformer_manager.get();

{
const GraphViewer graph_viewer(graph);
capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup);
capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup, graph_transformer_manager);

if (capabilities.empty()) {
return Status::OK();
Expand Down Expand Up @@ -211,7 +216,7 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params, const l
capabilities.clear();

const GraphViewer graph_viewer(graph);
capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup);
capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup, graph_transformer_manager);

// all nodes with an index >= first_new_node with domain of kMSInternalNHWCDomain should be in the capabilities
InlinedHashSet<NodeIndex> new_nodes_in_capabilities;
Expand Down Expand Up @@ -248,6 +253,7 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params, const l
// It also does not perform layout transformation. This will be done during normal partitioning.
static Status GetCapabilityForEPForAotInlining(const GraphViewer& graph_viewer,
const KernelRegistryManager& kernel_registry_mgr,
const GraphTransformerManager& graph_transformer_mgr,
const IExecutionProvider& current_ep,
const logging::Logger& logger,
std::vector<std::unique_ptr<ComputeCapability>>& capabilities) {
Expand All @@ -260,7 +266,7 @@ static Status GetCapabilityForEPForAotInlining(const GraphViewer& graph_viewer,
logger};

// TODO: Provide EP with a capability to look inside the functions.
capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup);
capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup, graph_transformer_mgr);

return Status::OK();
}
Expand All @@ -280,11 +286,14 @@ static Node* PlaceNode(Graph& graph, const IndexedSubGraph& capability,
IExecutionProvider::FusionStyle fusion_style,
const std::string& provider_type,
GraphPartitioner::Mode mode,
int& fused_node_unique_id) {
int& fused_node_unique_id,
bool* subgraph_assigned_to_ep) {
Node* result = nullptr;
*subgraph_assigned_to_ep = false;

if (nullptr == capability.GetMetaDef()) {
TryAssignSingleNode(graph, capability, provider_type);
*subgraph_assigned_to_ep = true;
} else {
// The <provider> can run a fused <sub_graph> in the <graph>.

Expand Down Expand Up @@ -347,12 +356,17 @@ static Node* PlaceNode(Graph& graph, const IndexedSubGraph& capability,
}
}
}
*subgraph_assigned_to_ep = true;
}
}

return result;
}

// Applies the EP-provided optimization described by `this_optimization` to `graph` and
// updates `cc_to_update` (the ComputeCapability assigned to the EP) to reflect any node
// additions/removals made by the optimization.
// NOTE(review): currently a stub. The original body was empty, which is undefined
// behavior for a value-returning function (control flowed off the end without a
// return). Until the implementation lands, explicitly return success and mark the
// parameters as intentionally unused to keep -Wunused-parameter clean.
static Status TransformGraph(Graph& graph, const ComputeCapability& this_optimization, ComputeCapability& cc_to_update) {
  ORT_UNUSED_PARAMETER(graph);
  ORT_UNUSED_PARAMETER(this_optimization);
  ORT_UNUSED_PARAMETER(cc_to_update);
  // TODO: invoke this_optimization.optimization_func and merge the resulting
  // node changes into cc_to_update.
  return Status::OK();
}

// for the current EP, recursively iterate through the Graph and any nested subgraphs (recursion is bottom-up).
// assign any nodes to the EP that are currently unassigned, and that the EP can handle.
static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr,
Expand All @@ -363,6 +377,7 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr,
int& fused_node_unique_id,
const layout_transformation::TransformLayoutFunction& transform_layout_fn,
const layout_transformation::DebugGraphFn& debug_graph_fn,
const onnxruntime::GraphTransformerManager& graph_transfomer_manager,
const logging::Logger& logger) {
// handle testing edge case where optimizers or constant lifting results in graph with no nodes.
// doing it here saves all providers checking for this in GetCapability
Expand All @@ -377,7 +392,7 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr,
// we pass through the FuncManager from the top level graph
ORT_RETURN_IF_ERROR(PartitionOnnxFormatModelImpl(*subgraph, func_mgr, kernel_registry_mgr,
fused_kernel_registry, current_ep, mode, fused_node_unique_id,
transform_layout_fn, debug_graph_fn, logger));
transform_layout_fn, debug_graph_fn, graph_transfomer_manager, logger));
}
}

Expand All @@ -400,7 +415,8 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr,
std::ref(capabilities),
mode,
std::cref(transform_layout_fn),
std::cref(debug_graph_fn)};
std::cref(debug_graph_fn),
std::cref(graph_transfomer_manager)};

ORT_RETURN_IF_ERROR(GetCapabilityForEP(get_capability_params, logger));
if (capabilities.empty()) {
Expand All @@ -426,7 +442,20 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr,
entry->sub_graph->GetMetaDef() != nullptr;
}));
for (auto& capability : capabilities) {
Node* n = PlaceNode(graph, *capability->sub_graph, fusion_style, type, mode, fused_node_unique_id);
bool subgraph_assigned_to_ep = false;
Node* n = PlaceNode(graph, *capability->sub_graph, fusion_style, type, mode, fused_node_unique_id, &subgraph_assigned_to_ep);

// If the subgraph is assigned to the EP and the ComputeCapability has nodes_to_optimize,
// run EP related optimizations and update ComputeCapability.
if (subgraph_assigned_to_ep && !capability->nodes_to_optimize.empty()) {
for (auto& optimization_cc : capability->nodes_to_optimize) {
if (optimization_cc->optimization_func) {
optimization_cc->optimization_func(graph, *optimization_cc, *capability);
// #TODO: Handle nested optimization ComputeCapability
}
}
}

if (n != nullptr) {
// searching in kernel registries, if no kernel registered for the fused_node, use compile approach
if (!KernelRegistryManager::HasImplementationOf(kernel_registry_mgr, *n, type, logger)) {
Expand Down Expand Up @@ -562,6 +591,7 @@ static Status InlineNodes(Graph& graph, bool& modified_graph) {

static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_providers,
const KernelRegistryManager& kernel_registry_mgr,
const GraphTransformerManager& graph_transformer_mgr,
Graph& graph,
const logging::Logger& logger,
InlinedHashSet<std::string>& not_inlined,
Expand All @@ -578,6 +608,7 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide
// we pass through the FuncManager from the top level graph
ORT_RETURN_IF_ERROR(InlineFunctionsAOTImpl(execution_providers,
kernel_registry_mgr,
graph_transformer_mgr,
*subgraph,
logger,
not_inlined,
Expand All @@ -603,7 +634,7 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide
InlinedHashSet<NodeIndex> claimed_by_ep;
for (const auto& ep : execution_providers) {
std::vector<std::unique_ptr<ComputeCapability>> capabilities;
ORT_RETURN_IF_ERROR(GetCapabilityForEPForAotInlining(graph_viewer, kernel_registry_mgr, *ep, logger,
ORT_RETURN_IF_ERROR(GetCapabilityForEPForAotInlining(graph_viewer, kernel_registry_mgr, graph_transformer_mgr, * ep, logger,
capabilities));
for (auto& capability : capabilities) {
const auto& nodes = capability->sub_graph->nodes;
Expand Down Expand Up @@ -743,6 +774,7 @@ static Status PartitionOnnxFormatModel(const PartitionParams& partition_params,
auto& fused_kernel_registry = partition_params.fused_kernel_registry.get();
auto& fused_node_unique_id = partition_params.fused_node_unique_id.get();
const auto& transform_layout_function = partition_params.transform_layout_function;
const auto& graph_transformer_manager = partition_params.graph_transformer_manager;

do {
// process full graph with each EP
Expand All @@ -751,6 +783,7 @@ static Status PartitionOnnxFormatModel(const PartitionParams& partition_params,
fused_kernel_registry, *ep, mode, fused_node_unique_id,
transform_layout_function,
partition_params.debug_graph_fn,
graph_transformer_manager,
logger));
}

Expand Down Expand Up @@ -802,6 +835,7 @@ static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_param
GraphPartitioner::Mode::kOrtFormatLoad,
std::cref(partition_params.transform_layout_function),
std::cref(partition_params.debug_graph_fn),
std::cref(partition_params.graph_transformer_manager),
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
};
// clang-format on
Expand Down Expand Up @@ -917,6 +951,7 @@ Status GraphPartitioner::InlineFunctionsAOT(Model& model,
size_t inlined_count = 0;
ORT_RETURN_IF_ERROR(InlineFunctionsAOTImpl(execution_providers,
kernel_registry_manager,
graph_transformer_mgr_,
graph,
logger,
not_inlined,
Expand Down Expand Up @@ -975,6 +1010,7 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr,
std::ref(fused_node_unique_id),
std::cref(transform_layout_function),
std::cref(debug_graph_fn),
std::cref(graph_transformer_mgr_),
};

#else // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
Expand Down
5 changes: 4 additions & 1 deletion onnxruntime/core/framework/graph_partitioner.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "core/graph/graph.h"
#include "core/framework/fuse_nodes_funcs.h"
#include "core/framework/transform_layout_functions.h"
#include "core/optimizer/graph_transformer_mgr.h"

namespace onnxruntime {

Expand All @@ -24,8 +25,9 @@ class GraphPartitioner {
};

// The order of providers represents the user preference.
GraphPartitioner(KernelRegistryManager& kernel_registry_mgr, const ExecutionProviders& providers)
GraphPartitioner(KernelRegistryManager& kernel_registry_mgr, const GraphTransformerManager& graph_transformer_mgr, const ExecutionProviders& providers)
: kernel_registry_mgr_(kernel_registry_mgr),
graph_transformer_mgr_(graph_transformer_mgr),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of wiring in the GraphTransformerManager would it be better to add a new standalone class that provides lookup-based access to a set of L2 optimizers that are directly usable by an EP? That would decouple the GraphPartitioner from the optimizer registration/lookup a bit more. I don't think GraphTransformerManager is providing value in this case as I don't think we need to loop.

The registration/lookup class for re-usable optimizers could be a singleton with a static 'create' method that calls GenerateTransformersForEP and saves the returned list. We could have InferenceSession call the 'create' method to provide any other required things like the CPU allocator, and potentially apply the optimization level when doing so if we want to do that on the ORT side. GraphPartitioner could call a static 'get' method to get the instance so we don't need to wire it through from the inference session.

Copy link
Contributor Author

@chilo-ms chilo-ms Feb 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a new standalone class called GraphOptimizerRegistry to mainly provide registration and lookup access.
This lookup class is a singleton and its 'create' method is being called during InferenceSession initialization and it calls GenerateTransformersForEP and saves the returned list. BTW, not sure this 'create' function needs to be static?

Given that this lookup instance is a singleton using static variable, provider bridge now can access this instance, which means we don't need to get the instance at graph partitioner and feed to GetCapability. Also, we don't need change GetCapability's signature to add a new parameter (i.e. GraphOptimizerRegistry& ).

providers_(providers) {
}

Expand Down Expand Up @@ -64,6 +66,7 @@ class GraphPartitioner {

KernelRegistryManager& kernel_registry_mgr_;
const ExecutionProviders& providers_;
const GraphTransformerManager& graph_transformer_mgr_;
};

} // namespace onnxruntime
15 changes: 14 additions & 1 deletion onnxruntime/core/optimizer/constant_folding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,19 @@ ConstantFolding::ConstantFolding(const IExecutionProvider& execution_provider,
execution_provider_(execution_provider) {
}

// Overload that lets a derived class supply its own transformer name (the primary
// constructor hard-codes the "ConstantFolding" name via the base GraphTransformer).
// Parameters:
//   name                            - transformer name registered with GraphTransformer.
//   execution_provider              - EP used to execute nodes during folding (CPU EP is
//                                     expected by callers per the header comment).
//   skip_dequantize_linear          - if true, DequantizeLinear nodes are not folded.
//   config_options                  - session config options consulted during folding.
//   compatible_execution_providers  - EPs whose assigned nodes this transformer may modify.
//   excluded_initializers           - initializer names that must never be folded away.
// NOTE: member-initializer order here differs from a typical declaration order
// (execution_provider_ listed last); actual initialization follows the member
// declaration order in the class, not this list.
ConstantFolding::ConstantFolding(const std::string& name,
const IExecutionProvider& execution_provider,
bool skip_dequantize_linear,
const ConfigOptions& config_options,
const InlinedHashSet<std::string_view>& compatible_execution_providers,
const InlinedHashSet<std::string>& excluded_initializers) noexcept
: GraphTransformer(name, compatible_execution_providers),
skip_dequantize_linear_(skip_dequantize_linear),
config_options_(config_options),
excluded_initializers_(excluded_initializers),
execution_provider_(execution_provider) {
}

// We need to handle a Shape node separately as the input doesn't need to be a constant initializer for
// Shape to be able to be constant folded.
static bool ConstantFoldShapeNode(Graph& graph, Node& node) {
Expand Down Expand Up @@ -144,7 +157,7 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level,

for (NodeIndex i : order) {
auto* node = graph.GetNode(i);
if (!node) {
if (!node || !AllowConstantFolding(*node)) {
continue;
}

Expand Down
11 changes: 11 additions & 0 deletions onnxruntime/core/optimizer/constant_folding.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,17 @@ class ConstantFolding : public GraphTransformer {
const InlinedHashSet<std::string_view>& compatible_execution_providers = {},
const InlinedHashSet<std::string>& excluded_initializers = {}) noexcept;

/* Same as above but with a name provided by derived class.
chilo-ms marked this conversation as resolved.
Show resolved Hide resolved
*/
ConstantFolding(const std::string& name,
const IExecutionProvider& execution_provider,
bool skip_dequantize_linear,
const ConfigOptions& config_options,
const InlinedHashSet<std::string_view>& compatible_execution_providers = {},
const InlinedHashSet<std::string>& excluded_initializers = {}) noexcept;

virtual bool AllowConstantFolding(const Node& node) const { return true; }
chilo-ms marked this conversation as resolved.
Show resolved Hide resolved

private:
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;

Expand Down
Loading
Loading