Implement batched gemm wmma (RDNA batched gemm) based on wmma cshuffle v3 #2319

Open: wants to merge 15 commits into develop

Conversation

krithalith (Contributor)

Proposed changes

This PR implements batched gemm for wmma, closely based on the existing universal gemm wmma (cshuffle v3).

A new device-level struct, DeviceBatchedGemm_Wmma_CShuffleV3, was added; it is very closely based on DeviceGemm_Wmma_CShuffleV3. Note that since batched gemms must inherit from the DeviceBatchedGemm base class, there is currently no support for some of the extra members that appear in the DeviceGemmV2 base class (and are not in the DeviceGemm base class). Effectively this means that k batching and permuteA/B are not supported right now. This could be resolved by introducing a new batched gemm base class with those extra features, but that would probably also require changes in the instance factories and the profiler. I believe these features are not required for now.
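For illustration, here is a minimal sketch of the class relationship described above; it is not the PR's actual code, and the real template parameter lists and members are elided:

    // Sketch only: the real structs in composable_kernel carry full
    // layout/datatype/tuning template parameter lists.
    struct DeviceBatchedGemm { /* batched interface: no KBatch, no permuteA/B */ };
    struct DeviceGemmV2      { /* adds KBatch and permuteA/B on top of DeviceGemm */ };

    // Batched gemms must derive from DeviceBatchedGemm, so the DeviceGemmV2
    // extras (k batching, permuteA/B) are unavailable without a new base class.
    struct DeviceBatchedGemm_Wmma_CShuffleV3 : DeviceBatchedGemm
    {
        // internals mirror DeviceGemm_Wmma_CShuffleV3, plus per-matrix batch strides
    };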

A new custom kernel, kernel_batched_gemm_wmma_cshuffle_v3(), was added, closely based on kernel_gemm_wmma_cshuffle_v3(). To implement batching, the kernel is simply launched with an increased number of workgroups: the gridY dimension grows from 1 to the batch count. The gridZ dimension could not be used, since it is already occupied by the k-batching calculations, but gridY was completely unused.
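As a rough illustration (assumed names and parameters, not the PR's actual kernel signature), the batch handling amounts to:

    // Hypothetical sketch of the gridY batching scheme; float stands in for the
    // real wmma datatypes, and all tile logic is elided.
    __global__ void kernel_batched_gemm_sketch(const float* a, const float* b, float* c,
                                               long batch_stride_a, long batch_stride_b,
                                               long batch_stride_c)
    {
        const int g_idx = blockIdx.y; // batch index: gridY runs over 0..batch-1

        const float* a_batch = a + g_idx * batch_stride_a;
        const float* b_batch = b + g_idx * batch_stride_b;
        float*       c_batch = c + g_idx * batch_stride_c;

        // ... usual per-tile wmma gemm on (a_batch, b_batch, c_batch), with
        // blockIdx.x selecting the M/N tile and blockIdx.z left to k batching ...
    }

    // Host side: the launch grid simply gains a Y dimension equal to batch.
    // dim3 grid(num_tiles, batch, kbatch /* kbatch == 1 without split-k */);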

Instances for the new operation were added to the instance factory; they directly mirror those for DeviceGemm_Wmma_CShuffleV3 (see the usage sketch after this list). The instances support:

Datatypes: f16-f16-f16, bf16-bf16-bf16
Layouts: Row-Row-Row, Row-Column-Row, Column-Row-Row, Column-Column-Row
Padding: No padding, MNK padding
Pipelines: v1, v3
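
A hypothetical sketch of retrieving these instances through CK's instance factory follows; the exact template argument list for DeviceBatchedGemm is an assumption modeled on the non-batched factories:

    // Assumed sketch: namespaces and the DeviceBatchedGemm parameter order may
    // differ slightly from the actual headers.
    using Row         = ck::tensor_layout::gemm::RowMajor;
    using F16         = ck::half_t;
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    using DeviceOp = ck::tensor_operation::device::DeviceBatchedGemm<
        Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>;

    // Returns the f16 Row-Row-Row instances added by this PR.
    const auto instances = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();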

Gtest tests were added, based on those for batched gemm xdl. They test all available datatypes and layouts. Tested on an RDNA3 Radeon RX 7900 XTX (gfx1100).
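For reference, a schematic of what such a typed-test setup might look like; the type list and the profiling call are assumptions modeled on the xdl tests:

    // Illustrative only; mirrors the spirit of the batched gemm xdl tests.
    #include <gtest/gtest.h>
    #include <tuple>

    template <typename Tuple>
    class TestBatchedGemmWmma : public ::testing::Test {};

    using KernelTypes = ::testing::Types<
        std::tuple<ck::half_t, ck::half_t, ck::half_t>,    // f16-f16-f16
        std::tuple<ck::bhalf_t, ck::bhalf_t, ck::bhalf_t>  // bf16-bf16-bf16
        >;
    TYPED_TEST_SUITE(TestBatchedGemmWmma, KernelTypes);

    TYPED_TEST(TestBatchedGemmWmma, BatchedShapes)
    {
        // would invoke profile_batched_gemm_impl() per layout combination
    }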

Checklist

Please put an x into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.

  • I have added tests relevant to the introduced functionality, and the unit tests are passing locally
  • I have added the test to the REGRESSION_TESTS list defined at the top of tests/CMakeLists.txt, IF the test takes more than 30 seconds to run.
  • I have added inline documentation which enables the maintainers to understand the motivation
  • I have removed the stale documentation which is no longer relevant after this pull request
  • (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request
  • I have run clang-format on all changed files
  • Any dependent changes have been merged

Discussion

If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered.

@krithalith krithalith force-pushed the 2025_06_10-implement-batched-gemm-wmma-cshuffle-v3 branch from 44f21c8 to 6710edb on June 10, 2025 13:40
@krithalith krithalith requested a review from bartekxk June 10, 2025 13:45
kiefer added 13 commits June 18, 2025 08:30
…gemm in general to gfx11 and gfx12 categories, and split existing batched_gemm test into xdl and wmma versions. Updated profiler and instance factory. For now only adding f16-row-row-row-GemmDefault. For now actual device instance list is empty.
…leV3 and make sure it's used in the instance factory and tests. Currently the new batched device level struct cannot actually handle batching, but it does pass tests with a trivial batch size of 1, meaning that the overall structure is good.
…eV3. Batching arguments not passed to kernel yet.
…ffleV3. In principle the whole thing works now, just need to add other data types and perhaps do some cleanup.
… shapes. Some of the original test cases for batched gemm do not work based on cshuffle v3 because the dimensions are too small.
…main-k-block-loop, check compute type, packed buffer size calc. Ported new instance lists.
kiefer added 2 commits June 18, 2025 08:31
…ile_batched_gemm_impl() from test_batched_gemm_wmma to match latest definition of that function.
@krithalith krithalith force-pushed the 2025_06_10-implement-batched-gemm-wmma-cshuffle-v3 branch from 6710edb to d436ed1 on June 18, 2025 08:52
@krithalith (Contributor, Author)

I just rebased on develop again, and since my fix to the argument order of profile_batched_gemm_impl() was merged, I had to update the argument order in the newly introduced test_batched_gemm_wmma one last time (verified normal performance).
