
[MLAS] Add an NHWC implementation of convolution to avoid transposes #26834

Merged
hariharans29 merged 180 commits into microsoft:main from orlmon01:main
May 14, 2026
Conversation

@orlmon01
Contributor

  • Modified the CPU EP to specify channels_last when the data format is NHWC
  • Added a FusedNhwcConv kernel
  • Implemented the kernel in MLAS
  • Added compiler guards so it is only used with KleidiAI (for now; can be removed if needed)
  • Added unit tests

Description

ONNX Runtime currently uses NCHW as its default data layout. For optimisations and kernels that perform better in NHWC layout, or where the data layout is NHWC to begin with, Transpose nodes are added around the layers. This patch seeks to eliminate those Transposes for convolutions where they would otherwise cause a performance decrease.
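To make the layout difference concrete, here is a minimal sketch (not code from this patch; the helper names are hypothetical) of the two flat-index formulas and the copy that a Transpose node between the layouts amounts to:

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helpers, not from the patch: flat offsets for the two layouts.
// NCHW: channels are the slowest-varying dim after batch; NHWC: channels are innermost.
inline size_t OffsetNchw(size_t n, size_t c, size_t h, size_t w,
                         size_t C, size_t H, size_t W) {
  return ((n * C + c) * H + h) * W + w;
}

inline size_t OffsetNhwc(size_t n, size_t c, size_t h, size_t w,
                         size_t C, size_t H, size_t W) {
  return ((n * H + h) * W + w) * C + c;
}

// A Transpose node inserted between the layouts is equivalent to this copy,
// which is the data movement the patch removes around NHWC convolutions.
void NchwToNhwc(const float* src, float* dst,
                size_t N, size_t C, size_t H, size_t W) {
  for (size_t n = 0; n < N; ++n)
    for (size_t c = 0; c < C; ++c)
      for (size_t h = 0; h < H; ++h)
        for (size_t w = 0; w < W; ++w)
          dst[OffsetNhwc(n, c, h, w, C, H, W)] =
              src[OffsetNchw(n, c, h, w, C, H, W)];
}
```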

Motivation and Context

This is a KleidiAI-specific implementation of the feature. It currently supports only regular convolutions, with depthwise to follow; as a result the filter constraints are somewhat strict for now.

…transposes

* Modified the CPU EP to specify channels_last when the data format is NHWC
* Added a FusedNhwcConv kernel
* Implemented the kernel in MLAS
* Added compiler guards so it is only used with KleidiAI (for now; can be removed if needed)
* Added unit tests

Signed-off-by: Orlaith Monahan <orlaith.monahan@arm.com>
@orlmon01
Contributor Author

@microsoft-github-policy-service agree company="Arm"

@orlmon01 orlmon01 marked this pull request as draft December 19, 2025 12:34
@orlmon01 orlmon01 marked this pull request as ready for review December 19, 2025 12:35
@orlmon01
Contributor Author

Feedback appreciated, as this PR makes quite a lot of changes to the codebase well outside the normal KleidiAI scope.

Signed-off-by: Orlaith Monahan <orlaith.monahan@arm.com>
@Rohanjames1997
Contributor

Hi @orlmon01, I imagine that avoiding transposes also improves performance.
Do you have any performance results to share?
TIA!

@orlmon01
Contributor Author

orlmon01 commented Jan 6, 2026

Hi @orlmon01, I imagine that avoiding transposes also improves performance. Do you have any performance results to share? TIA!

Hiya,
Sorry, just back after the holidays. Yes, there is a performance increase. It depends on the model. Ones with multiple consecutive Convolutions where transposes can be eliminated will see a larger speedup. Even with the limited range of convolutions it's implemented for there should still be a performance increase in most cases.

I have some numbers somewhere from a Mobilenet model I was using for testing that I'll add in a bit, once I find / regenerate them. :)

@orlmon01
Contributor Author

orlmon01 commented Jan 6, 2026

mobilenet model without the current patch:

Setting intra_op_num_threads to 1
Overriding dimension with name, N, to 1
Overriding dimension with name, T, to 1000
Overriding dimension with name, cache_T_attn, to 32
Overriding dimension with name, right_context, to 5
Session creation time cost: 0.020627 s
First inference time cost: 10 ms
Total inference time cost: 1.33257 s
Total inference requests: 200
Average inference time cost total: 6.662851 ms
Total inference run time: 1.33266 s
Number of inferences per second: 150.075
Avg CPU usage: 16 %
Peak working set size: 85429583872 bytes
Avg CPU usage:16
Peak working set size:85429583872
Runs:200
Min Latency: 0.006193 s
Max Latency: 0.007625 s
P50 Latency: 0.00666992 s
P90 Latency: 0.00686983 s
P95 Latency: 0.00694425 s
P99 Latency: 0.00733196 s
P999 Latency: 0.007625 s

Same model with changes:

Setting intra_op_num_threads to 1
Overriding dimension with name, N, to 1
Overriding dimension with name, T, to 1000
Overriding dimension with name, cache_T_attn, to 32
Overriding dimension with name, right_context, to 5
Session creation time cost: 0.0217724 s
First inference time cost: 7 ms
Total inference time cost: 1.12897 s
Total inference requests: 200
Average inference time cost total: 5.644857 ms
Total inference run time: 1.12905 s
Number of inferences per second: 177.14
Avg CPU usage: 16 %
Peak working set size: 80362864640 bytes
Avg CPU usage:16
Peak working set size:80362864640
Runs:200
Min Latency: 0.00527438 s
Max Latency: 0.006706 s
P50 Latency: 0.00566529 s
P90 Latency: 0.00579958 s
P95 Latency: 0.00585429 s
P99 Latency: 0.00639058 s
P999 Latency: 0.006706 s

@edgchen1
Contributor

edgchen1 commented Jan 8, 2026

/azp run Linux QNN CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI,Windows ARM64 QNN CI Pipeline,Windows GPU Doc Gen CI Pipeline

@azure-pipelines

Azure Pipelines successfully started running 4 pipeline(s).

Update to the internal_testings_tests helper macros for file expansion so it works on other platforms
Fix for failing ConvDepthwiseFloat test, allows for a small tolerance when running on different hardware
Fix for failing TestSaveAndLoadOrtModel test
Make sure the model being saved / loaded is being done from a writeable location
Fix for undeclared identifier linker error
@hariharans29
Member

/azp run Linux QNN CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI,Windows ARM64 QNN CI Pipeline,Windows GPU Doc Gen CI Pipeline

@azure-pipelines

Azure Pipelines successfully started running 4 pipeline(s).

Signed-off-by: Orlaith Monahan <orlaith.monahan@arm.com>
@hariharans29
Member

/azp run Linux QNN CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI,Windows ARM64 QNN CI Pipeline,Windows GPU Doc Gen CI Pipeline

@azure-pipelines

Azure Pipelines successfully started running 4 pipeline(s).

Contributor

Copilot AI left a comment


Pull request overview

This PR adds an NHWC (channels-last) implementation of convolution operations to avoid costly transpose operations in the CPU execution provider. The implementation includes KleidiAI-specific optimizations and a fallback path for NHWC convolutions.

Changes:

  • Added NhwcFusedConv kernel for float32 convolutions in NHWC layout (KleidiAI-guarded)
  • Implemented NHWC fast path and fallback path with explicit NHWC↔NCHW conversions in MLAS
  • Extended test infrastructure to resolve paths dynamically and filter NHWC transformers in existing tests
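The fast-path idea in the bullets above can be illustrated with the simplest case (a sketch, not the MLAS kernel): in NHWC layout, a 1x1 convolution needs no im2col and no transpose at all, because channels are contiguous and the operation reduces to a plain GEMM.

```cpp
#include <cstddef>
#include <vector>

// Illustration only (hypothetical function, not from MLAS): a 1x1 convolution
// on NHWC data is a GEMM. The input is viewed as [N*H*W, Cin] row-major, the
// filter as [Cin, Cout]; the output [N*H*W, Cout] is already NHWC.
std::vector<float> Conv1x1Nhwc(const std::vector<float>& x,  // N*H*W x Cin
                               const std::vector<float>& w,  // Cin x Cout
                               size_t pixels, size_t cin, size_t cout) {
  std::vector<float> y(pixels * cout, 0.0f);
  for (size_t p = 0; p < pixels; ++p)
    for (size_t ci = 0; ci < cin; ++ci) {
      const float xv = x[p * cin + ci];
      for (size_t co = 0; co < cout; ++co)
        y[p * cout + co] += xv * w[ci * cout + co];
    }
  return y;
}
```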

Reviewed changes

Copilot reviewed 24 out of 24 changed files in this pull request and generated 7 comments.

Show a summary per file
File Description
onnxruntime/core/providers/cpu/nn/conv.h Added channels_last_ flag to Conv kernel
onnxruntime/core/providers/cpu/nn/conv.cc Implemented NHWC convolution logic with fast path and fallback
onnxruntime/core/optimizer/nhwc_transformer.cc Added KleidiAI filter and FusedConv sum input handling
onnxruntime/core/mlas/lib/convolve.cpp Added ChannelsLast parameter to MlasConvPrepare
onnxruntime/core/mlas/inc/mlas.h Added ChannelsLast field to MLAS_CONV_PARAMETERS
onnxruntime/contrib_ops/cpu/fused_conv.cc Registered NhwcFusedConv kernel
onnxruntime/test/optimizer/nhwc_transformer_test.cc Added depthwise convolution test case
onnxruntime/test/optimizer/fuse_initializers_transformer_test.cc Filtered NhwcTransformer in tests
onnxruntime/test/optimizer/conv_add_act_test.cc Updated to handle both FusedConv variants
onnxruntime/test/internal_testing_ep/internal_testing_tests.cc Added path resolution utilities
onnxruntime/test/framework/ort_model_only_test.cc Added path resolution with diagnostic output
onnxruntime/core/util/math_cpu.cc Added Im2col instantiation for float NHWC
onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp Updated to support channels-last input


Comment thread onnxruntime/core/optimizer/nhwc_transformer.cc Outdated
Comment thread onnxruntime/core/framework/kernel_type_str_resolver.cc Outdated
Comment thread onnxruntime/test/framework/ort_model_only_test.cc Outdated
Comment thread onnxruntime/core/optimizer/nhwc_transformer.cc Outdated
Comment thread onnxruntime/test/internal_testing_ep/internal_testing_tests.cc
Comment thread onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
Comment thread onnxruntime/core/providers/cpu/nn/conv.cc Outdated
@hariharans29 hariharans29 changed the title Add an NHWC implementation of convolution to avoid transposes [MLAS] Add an NHWC implementation of convolution to avoid transposes Jan 21, 2026
orlmon01 added 27 commits May 12, 2026 10:23
* NhwcFusedConv should now be available in minimal builds as the resolver bytes contain it

Signed-off-by: Orlaith Monahan <orlaith.monahan@arm.com>
* RHS hashing now uses the full tensor to ensure uniqueness
* LHS no longer uses hashing as it's unnecessary

Signed-off-by: Orlaith Monahan <orlaith.monahan@arm.com>
Signed-off-by: Orlaith Monahan <orlaith.monahan@arm.com>
Signed-off-by: Orlaith Monahan <orlaith.monahan@arm.com>
Signed-off-by: Orlaith Monahan <orlaith.monahan@arm.com>
* The change replaces the unsafe long-lived thread_local RHS cache with a kernel-owned packed-filter cache for NHWC float conv when the filter and optional bias are constant initializers. The packed RHS is built once per kernel instance and then reused safely for that kernel’s lifetime, which avoids pointer/hash aliasing issues without pulling in the larger ORT prepack machinery.
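The kernel-owned cache described in this commit message can be sketched roughly as follows (simplified, with hypothetical names; the real change lives in the conv kernel and the MLAS packing code):

```cpp
#include <mutex>
#include <utility>
#include <vector>

// Stand-in for the real MLAS RHS-packing routine; here it just copies.
std::vector<float> PackFilter(const std::vector<float>& filter) {
  return filter;  // placeholder for the actual packing transform
}

class NhwcConvKernel {
 public:
  explicit NhwcConvKernel(std::vector<float> constant_filter)
      : filter_(std::move(constant_filter)) {}

  // Safe under concurrent Compute() calls: the packed RHS is built exactly
  // once per kernel instance and reused for the kernel's lifetime. This is
  // the aliasing fix: a thread_local cache keyed by pointer/hash can collide
  // across sessions, whereas kernel-owned state cannot.
  const std::vector<float>& PackedRhs() {
    std::call_once(packed_once_, [this] { packed_rhs_ = PackFilter(filter_); });
    return packed_rhs_;
  }

 private:
  std::vector<float> filter_;
  std::vector<float> packed_rhs_;
  std::once_flag packed_once_;
};
```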

Signed-off-by: Orlaith Monahan <orlaith.monahan@arm.com>
Signed-off-by: Orlaith Monahan <orlaith.monahan@arm.com>
…mationRequiredOps test

Signed-off-by: Orlaith Monahan <orlaith.monahan@arm.com>
…solverCanResolveNhwcFusedConv and have it check load time instead of the saved model

Signed-off-by: Orlaith Monahan <orlaith.monahan@arm.com>
…r needed

Signed-off-by: Orlaith Monahan <orlaith.monahan@arm.com>
Signed-off-by: Orlaith Monahan <orlaith.monahan@arm.com>
…d tests

Signed-off-by: Orlaith Monahan <orlaith.monahan@arm.com>
Signed-off-by: Orlaith Monahan <orlaith.monahan@arm.com>
…transposes

* Modified the CPU EP to specify channels_last when the data format is NHWC
* Added a FusedNhwcConv kernel
* Implemented the kernel in MLAS
* Added compiler guards so it is only used with KleidiAI (for now; can be removed if needed)
* Added unit tests
* Rebased
* Docs regenerated

Signed-off-by: Orlaith Monahan <orlaith.monahan@arm.com>
Signed-off-by: Orlaith Monahan <orlaith.monahan@arm.com>
Signed-off-by: Orlaith Monahan <orlaith.monahan@arm.com>
@hariharans29 hariharans29 merged commit a8ba94a into microsoft:main May 14, 2026
87 checks passed
