
Enabling L2+ Optimizations for EPs #23517

Draft: wants to merge 37 commits into base: main

Commits (37):
1d5ca89  init (chilo-ms, Jan 21, 2025)
e9119d5  include GraphTransformerManager to GetCapability (chilo-ms, Jan 26, 2025)
b7a0b79  Add GraphTransformerManager for EP, optimization function and Compute… (chilo-ms, Jan 26, 2025)
3b28ffc  refine GraphTransformerManager for EP, optimization function and Comp… (chilo-ms, Jan 28, 2025)
309341e  TRT EP creates optimization compute capability (chilo-ms, Jan 28, 2025)
d0cbc65  add comments (chilo-ms, Jan 28, 2025)
b239db0  remove unnecessary code (chilo-ms, Jan 28, 2025)
a83dd11  remove commented code (chilo-ms, Jan 29, 2025)
372342c  add a function to include DQ that is filtered out by TRT parser (chilo-ms, Jan 29, 2025)
39fa897  add standalone GraphOptimizerRegistry as singleton (chilo-ms, Feb 3, 2025)
627a00a  remove redundant code (chilo-ms, Feb 3, 2025)
06ca086  remove redundant code (chilo-ms, Feb 3, 2025)
a965ffb  remove redundant code (chilo-ms, Feb 3, 2025)
4c2697c  add back function (chilo-ms, Feb 4, 2025)
2b81789  changed code per reviewer (chilo-ms, Feb 4, 2025)
0c10cd4  don't create optimizer instances until EP requests it by calling GetO… (chilo-ms, Feb 6, 2025)
3360dfd  minor modification (chilo-ms, Feb 6, 2025)
5f7da9f  fix compiler error (chilo-ms, Feb 7, 2025)
e610bc8  remove unnecessary member function (chilo-ms, Feb 7, 2025)
e95f2c3  lintrunner -a (chilo-ms, Feb 7, 2025)
bad19b9  handle status (chilo-ms, Feb 7, 2025)
d4968cb  remove unnecessary code (chilo-ms, Feb 7, 2025)
df5aca9  add GetMutableMetaDef (chilo-ms, Feb 9, 2025)
60d9599  update TRT EP (chilo-ms, Feb 9, 2025)
3c46897  refactor the code per reviewer's suggestions (chilo-ms, Feb 18, 2025)
ee35614  remove unnecessary code (chilo-ms, Feb 18, 2025)
958706e  fix format (chilo-ms, Feb 18, 2025)
5ebb117  use session logger for optimization function (chilo-ms, Feb 18, 2025)
08b85f9  add ORT_UNUSED_PARAMETER for param (chilo-ms, Feb 18, 2025)
1e0ae2e  fix compiler warnings/errors (chilo-ms, Feb 18, 2025)
4ee99b6  fix compiler warning/error (chilo-ms, Feb 18, 2025)
718ab98  fix compiler warnings/errors (chilo-ms, Feb 18, 2025)
644b837  run ConstantFoldingDQ only when trt_dla_enable is true (chilo-ms, Feb 20, 2025)
a2bfa09  Merge branch 'main' into chi/ort_enable_l2_plus_opt_for_ep (chilo-ms, Feb 21, 2025)
5b8bb7b  fix compiler error when resolving the conflicts (chilo-ms, Feb 21, 2025)
46e09d3  fix compiler error when resolving the conflicts (chilo-ms, Feb 21, 2025)
56f0f52  lintrunner -a (chilo-ms, Feb 21, 2025)
6 changes: 6 additions & 0 deletions include/onnxruntime/core/graph/indexed_sub_graph.h
@@ -72,6 +72,12 @@ struct IndexedSubGraph {
return meta_def_.get();
}

/** Gets the mutable meta definition needed to represent this subgraph as a FunctionProto.
@returns MetaDef instance if it has been set. nullptr if not. */
MetaDef* GetMutableMetaDef() {
return meta_def_.get();
}

// Check if the accounting is enabled for the current EP
bool IsAccountingEnabled() const {
return resource_accountant != nullptr &&
16 changes: 16 additions & 0 deletions onnxruntime/core/framework/compute_capability.h
@@ -2,8 +2,10 @@
// Licensed under the MIT License.

#pragma once
#include <functional>
#include <vector>
#include "core/common/common.h"
#include "core/graph/indexed_sub_graph.h"
#include "core/graph/graph.h"

namespace onnxruntime {
// A structure encodes a subgraph and the method to run it.
@@ -21,5 +23,19 @@

ComputeCapability(std::unique_ptr<IndexedSubGraph> t_sub_graph)
: sub_graph(std::move(t_sub_graph)) {}

// Optional function to optimize this ComputeCapability.
// This will be called by ORT once the ComputeCapability is assigned to the EP.
// Signature: Status(Graph& graph, const ComputeCapability& this_optimization, ComputeCapability& cc_to_update, const logging::Logger& logger)
std::function<Status(Graph&, const ComputeCapability&, ComputeCapability&, const logging::Logger& logger)> optimization_func;

// Optional ComputeCapability instances for sets of nodes within this ComputeCapability that should be optimized.
// When an optimization is applied, ORT will update this ComputeCapability to reflect the changes made.
// IndexedSubGraph.nodes:
// - update based on RemovedNode/AddNode calls
// IndexedSubGraph.MetaDef (if present):
// - inputs and outputs will be unchanged
// - constant_initializers MAY change if we constant fold an initializer during optimization
std::vector<std::unique_ptr<ComputeCapability>> nodes_to_optimize;

};
} // namespace onnxruntime
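The contract of the new `optimization_func` / `nodes_to_optimize` members can be sketched with simplified stand-in types. Everything below (`Graph`, `Status`, `ComputeCapability`, `ApplyCapabilityOptimizations`) is a minimal mock for illustration, not the real ORT classes: the partitioner walks the nested capabilities and lets each optimization function update the outer capability in place.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <string>
#include <vector>

// Hypothetical stand-ins for the ORT types (illustration only).
struct Graph { std::vector<std::string> nodes; };
struct Status {
  bool ok = true;
  static Status OK() { return {}; }
};

struct ComputeCapability {
  std::vector<size_t> node_indices;  // stand-in for IndexedSubGraph.nodes
  // Optional optimization hook, mirroring the new optimization_func member.
  std::function<Status(Graph&, const ComputeCapability&, ComputeCapability&)> optimization_func;
  // Nested capabilities identifying node sets within this capability to optimize.
  std::vector<std::unique_ptr<ComputeCapability>> nodes_to_optimize;
};

// Mirrors the partitioner loop: run each nested optimization, letting it
// update the outer capability to reflect any changes it made.
inline Status ApplyCapabilityOptimizations(Graph& graph, ComputeCapability& capability) {
  for (auto& opt_cc : capability.nodes_to_optimize) {
    if (opt_cc->optimization_func) {
      Status status = opt_cc->optimization_func(graph, *opt_cc, capability);
      if (!status.ok) {
        return status;
      }
    }
  }
  return Status::OK();
}
```

An EP's GetCapability would populate `nodes_to_optimize` with the node sets it wants optimized (e.g. the DQ nodes to constant fold) and ORT would invoke the hooks once the capability is assigned.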
174 changes: 106 additions & 68 deletions onnxruntime/core/framework/graph_partitioner.cc
@@ -278,8 +278,56 @@ static Status GetCapabilityForEPForAotInlining(const GraphViewer& graph_viewer,
}

/**
 * Check whether the given IndexedSubGraph is available for assignment to a specific provider.
*/
static bool IsIndexedSubGraphAvailableForAssignment(Graph& graph,
const IndexedSubGraph& capability,
GraphPartitioner::Mode mode,
const std::string& provider_type) {
// The provider can run a single node in the <graph> if not using meta-defs.
if (capability.GetMetaDef() == nullptr && capability.nodes.size() == 1) {
auto* node = graph.GetNode(capability.nodes[0]);
if (nullptr != node && node->GetExecutionProviderType().empty()) {
// The node was not fused or assigned.
return true;
}
return false;
}

// if mode is kAssignOnly we want all nodes that can _potentially_ be taken by compiling EPs to be assigned,
// so that we aggregate the nodes covered and ensure the original nodes remain in the ORT format model by
// preventing level 2 and 3 optimizers from changing them. optimizers check the EP the node is assigned to
// and only make changes if the EP is on the optimizer's list of supported EPs. an EP that compiles nodes
// should never be on those lists.
//
// when the ORT format model is loaded we will process it normally with EP priority being applied for
// whichever EPs are enabled at the time.
//
// e.g. an Android NNAPI EP may take different/overlapping nodes to an iOS CoreML EP.
// We want the ORT format model to be able to be run as efficiently as possible on either platform,
// so we want all the nodes that either may take to be preserved. If we did not do this we would
// need to create one ORT format model for Android and one for iOS.
if (mode == GraphPartitioner::Mode::kAssignOnly) {
return true;
}

for (auto node_index : capability.nodes) {
const auto* node = graph.GetNode(node_index);
if ((nullptr == node) ||
(!node->GetExecutionProviderType().empty() && node->GetExecutionProviderType() != provider_type)) {
// The node was fused or assigned, so that the whole sub-graph will not be assigned to this <provider>
// The assumption is that this <provider> can only run the sub-graph as a whole unit.
return false;
}
}

return true;
}

/**
* Return a fused node or assign the nodes in the indexed subgraph to the current EP.
*
* \param graph
* \param capability
* \param kernel_registry_mgr
@@ -298,75 +346,42 @@ static Node* PlaceNode(Graph& graph, const IndexedSubGraph& capability,
if (nullptr == capability.GetMetaDef()) {
TryAssignSingleNode(graph, capability, provider_type);
} else {
    const bool acc_enabled = capability.IsAccountingEnabled();
    if (mode == GraphPartitioner::Mode::kNormal) {
      std::ostringstream oss;
      oss << provider_type << "_" << capability.GetMetaDef()->name << "_" << fused_node_unique_id++;
      std::string node_name = oss.str();

      Node* fused_node = nullptr;
      if (fusion_style == IExecutionProvider::FusionStyle::Function) {
        fused_node = &graph.FuseSubGraph(capability, node_name);
      } else {
        // create a fused node without copying everything to a Function body. The IndexedSubGraph will be passed
        // through to Compile via a filtered GraphViewer.
        fused_node = &graph.BeginFuseSubGraph(capability, node_name);
      }

      fused_node->SetExecutionProviderType(provider_type);
      if (acc_enabled) {
        // We account for the fused node. We operate under the assumption that the fused node would use
        // no more memory than the nodes we are fusing, and potentially less, so no threshold check is
        // needed here. All threshold checks are done within the EP.
        capability.ComputeAndAccountForNode(*fused_node);
      }

      result = fused_node;
    } else {
      // assign the nodes in the indexed subgraph to the current EP so that level 2+ optimizers will not change them.
      // This is used when exporting an ORT format model to maintain the original nodes and re-do the fusion
      // at runtime. The original nodes provide a fallback if fewer nodes can be fused at runtime due to device
      // capabilities.
      for (size_t i = 0, limit = capability.nodes.size(); i < limit; ++i) {
        auto* node = graph.GetNode(capability.nodes[i]);
        if (node != nullptr) {
          node->SetExecutionProviderType(provider_type);
          if (acc_enabled) {
            capability.AccountForNode(i);
          }
        }
      }
@@ -450,7 +465,30 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr,
entry->sub_graph->GetMetaDef() != nullptr;
}));
for (auto& capability : capabilities) {
// The <provider> can run a fused <sub_graph> in the <graph>.
// Check whether any node in the <sub_graph> was already assigned. If so, it cannot be stolen, as assignment
// is done in order of EP priority.
bool sub_graph_available_for_assignment = IsIndexedSubGraphAvailableForAssignment(graph, *capability->sub_graph, mode, type);

// If the <sub_graph> is available to be assigned to the EP and the ComputeCapability has nodes_to_optimize,
// run EP related optimizations and update ComputeCapability.
if (sub_graph_available_for_assignment && !capability->nodes_to_optimize.empty()) {
for (auto& optimization_cc : capability->nodes_to_optimize) {
if (optimization_cc->optimization_func) {
auto status = optimization_cc->optimization_func(graph, *optimization_cc, *capability, logger);
if (status != Status::OK()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, type, "The optimization function failed to finish.");
}
// TODO: handle nested optimization ComputeCapability instances
}
}
}

Node* n = nullptr;
if (sub_graph_available_for_assignment) {
n = PlaceNode(graph, *capability->sub_graph, fusion_style, type, mode, fused_node_unique_id);
}

if (n != nullptr) {
// searching in kernel registries, if no kernel registered for the fused_node, use compile approach
if (!KernelRegistryManager::HasImplementationOf(kernel_registry_mgr, *n, type, logger)) {
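The availability check that the hunk above factors out of PlaceNode can be sketched standalone. The types and the `SubGraphAvailable` helper below are hypothetical stand-ins, not ORT code; they show only the multi-node rule: a subgraph is claimable iff every node exists and is either unassigned or already assigned to this provider.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in: a node is unassigned ("") or owned by some EP.
struct Node { std::string ep; };

// Mirrors the core of IsIndexedSubGraphAvailableForAssignment for the
// meta-def (multi-node) case.
inline bool SubGraphAvailable(const std::vector<Node>& graph,
                              const std::vector<size_t>& subgraph_nodes,
                              const std::string& provider_type) {
  for (size_t idx : subgraph_nodes) {
    if (idx >= graph.size()) {
      return false;  // stand-in for a null node lookup
    }
    const std::string& ep = graph[idx].ep;
    if (!ep.empty() && ep != provider_type) {
      // A node was already taken by a higher-priority EP, so the whole
      // subgraph cannot be assigned as a unit.
      return false;
    }
  }
  return true;
}
```

This all-or-nothing rule is what lets the partitioner run the EP's optimization functions only on capabilities that will actually be assigned.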
13 changes: 11 additions & 2 deletions onnxruntime/core/optimizer/constant_folding.cc
@@ -21,7 +21,16 @@ ConstantFolding::ConstantFolding(const IExecutionProvider& execution_provider,
const ConfigOptions& config_options,
const InlinedHashSet<std::string_view>& compatible_execution_providers,
const InlinedHashSet<std::string>& excluded_initializers) noexcept
: ConstantFolding("ConstantFolding", execution_provider, skip_dequantize_linear, config_options, compatible_execution_providers, excluded_initializers) {
}

ConstantFolding::ConstantFolding(const std::string& name,
const IExecutionProvider& execution_provider,
bool skip_dequantize_linear,
const ConfigOptions& config_options,
const InlinedHashSet<std::string_view>& compatible_execution_providers,
const InlinedHashSet<std::string>& excluded_initializers) noexcept
: GraphTransformer(name, compatible_execution_providers),
skip_dequantize_linear_(skip_dequantize_linear),
config_options_(config_options),
excluded_initializers_(excluded_initializers),
@@ -144,7 +153,7 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level,

for (NodeIndex i : order) {
auto* node = graph.GetNode(i);
if (!node || !AllowConstantFolding(*node)) {
continue;
}

18 changes: 18 additions & 0 deletions onnxruntime/core/optimizer/constant_folding.h
@@ -28,6 +28,24 @@ class ConstantFolding : public GraphTransformer {
const InlinedHashSet<std::string_view>& compatible_execution_providers = {},
const InlinedHashSet<std::string>& excluded_initializers = {}) noexcept;

protected:
/**
 * Same as the constructor above, but with a name provided by a derived class.
*/
ConstantFolding(const std::string& name,
const IExecutionProvider& execution_provider,
bool skip_dequantize_linear,
const ConfigOptions& config_options,
const InlinedHashSet<std::string_view>& compatible_execution_providers = {},
const InlinedHashSet<std::string>& excluded_initializers = {}) noexcept;
/**
 * A derived class can override this virtual function to limit which nodes may be constant folded.
*/
virtual bool AllowConstantFolding(const Node& node) const {
ORT_UNUSED_PARAMETER(node);
return true;
}

private:
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;

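The `AllowConstantFolding` hook added above lets a derived transformer (such as the PR's ConstantFoldingDQ) restrict folding to a pre-selected node set. A minimal standalone sketch of the pattern, using mock types rather than the real GraphTransformer hierarchy:

```cpp
#include <cassert>
#include <set>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-ins: a Node with an op type, and a base pass whose
// traversal consults a virtual filter before folding each node.
struct Node { std::string op_type; };

class ConstantFoldingBase {
 public:
  virtual ~ConstantFoldingBase() = default;
  // Count how many nodes the pass would touch; the real pass would fold them.
  int Apply(const std::vector<Node>& nodes) const {
    int folded = 0;
    for (const auto& n : nodes) {
      if (!AllowConstantFolding(n)) continue;  // the new extension point
      ++folded;
    }
    return folded;
  }

 protected:
  // Default: fold everything, matching the base ConstantFolding behavior.
  virtual bool AllowConstantFolding(const Node&) const { return true; }
};

// Mirrors the idea of ConstantFoldingDQ: only fold an allowed set of op types.
class ConstantFoldingDQOnly : public ConstantFoldingBase {
 public:
  explicit ConstantFoldingDQOnly(std::set<std::string> allowed)
      : allowed_(std::move(allowed)) {}

 protected:
  bool AllowConstantFolding(const Node& node) const override {
    return allowed_.count(node.op_type) > 0;
  }

 private:
  std::set<std::string> allowed_;
};
```

The base class keeps ApplyImpl unchanged; derived passes only customize the predicate, which is why the diff adds a single `!AllowConstantFolding(*node)` check to the traversal.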
42 changes: 42 additions & 0 deletions onnxruntime/core/optimizer/graph_optimizer_registry.cc
@@ -0,0 +1,42 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <memory>
#include <mutex>
#include <string>

#include "core/optimizer/graph_optimizer_registry.h"
#include "core/optimizer/graph_transformer_utils.h"
#include "core/optimizer/selection_and_optimization_func.h"
#include "core/optimizer/qdq_transformer/constant_folding_dq_node.h"

using namespace ::onnxruntime::common;
namespace onnxruntime {

GraphOptimizerRegistry::GraphOptimizerRegistry() {}

std::optional<SelectionFunc> GraphOptimizerRegistry::GetSelectionFunc(std::string& name) const {

auto lookup = transformer_name_to_selection_func_.find(name);
if (lookup != transformer_name_to_selection_func_.end()) {
return lookup->second;
}
LOGS(*logger_, WARNING) << "Can't find selection function of " << name;
return std::nullopt;
}

common::Status GraphOptimizerRegistry::Create(
const onnxruntime::SessionOptions* sess_options,
const onnxruntime::IExecutionProvider* cpu_ep,
const logging::Logger* logger) {
session_options_ = sess_options;

cpu_ep_ = cpu_ep;
logger_ = logger;

// Add predefined transformer names and their selection functions
transformer_name_to_selection_func_[kConstantFoldingDQ] = ConstantFoldingDQFuncs::Select;

return Status::OK();
}

// Initialize static members
std::shared_ptr<GraphOptimizerRegistry> onnxruntime::GraphOptimizerRegistry::graph_optimizer_registry = nullptr;

std::mutex GraphOptimizerRegistry::registry_mutex;
} // namespace onnxruntime
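The registry's name-to-selection-function lookup, with its "return `std::nullopt` when the optimizer is unknown" contract, can be sketched standalone. The `OptimizerRegistry` and `SelectionFunc` below are hypothetical simplifications (the real SelectionFunc operates on a GraphViewer, not a list of op-type strings):

```cpp
#include <cassert>
#include <functional>
#include <optional>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical: a selection function picks node indices out of a graph,
// here modeled as a list of op-type strings.
using SelectionFunc =
    std::function<std::vector<size_t>(const std::vector<std::string>&)>;

class OptimizerRegistry {
 public:
  void Register(const std::string& name, SelectionFunc fn) {
    name_to_selection_func_[name] = std::move(fn);
  }

  // Same contract as GraphOptimizerRegistry::GetSelectionFunc: the caller
  // decides how to handle an unknown optimizer name.
  std::optional<SelectionFunc> GetSelectionFunc(const std::string& name) const {
    auto it = name_to_selection_func_.find(name);
    if (it != name_to_selection_func_.end()) {
      return it->second;
    }
    return std::nullopt;
  }

 private:
  std::unordered_map<std::string, SelectionFunc> name_to_selection_func_;
};
```

An EP would look up a predefined optimizer (e.g. the kConstantFoldingDQ entry registered in Create) by name and, if present, run its selection function to build the nodes_to_optimize capability.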