
Translate MmaOp patterns properly on Hopper #4072

Open · wants to merge 99 commits into base: main
Conversation

@jacobhinkle (Collaborator) commented Mar 13, 2025

#3986 fixes our most common use cases of MatmulOp and LinearOp translation on Hopper. It does so by scheduling the memory types and allocation domains of global intermediates during translation. However, when there is no translation and we are given an MmaOp directly, that approach fails. This PR instead propagates memory types and allocation domains while caching operands, so that we can properly set as_, bs_, and so forth. As a result, input fusions no longer need to differ between Hopper and Ampere: we can translate both cases in the same way, and the only difference will be in scheduling.

Note that this will also make it easier to maintain internal tooling that uses things like canonicalizeInputToBMNK.

Changes in this PR:

  • Remove avoid_intermediates argument to MatmulPattern::translateToMmaOp and update all call sites.
  • Remove some helper utilities in mma_utils.cpp
  • Introduce scheduler_utils::scheduleInputToSkipIntermediates which will schedule the allocation domains and memory types of consumers of inputs recursively to avoid "metadata ops" at the beginning of a fusion.
  • Rearrange HopperMultipleMatmulScheduler to remove defineOperandCaches and move cacheInputsAndOutputs after pattern translation but before findRoles. Also cacheInputsAndOutputs now uses scheduler_utils::scheduleInputToSkipIntermediates and defines the operand roles as the last gmem tensor returned by that utility.
  • Unguard AllocationDomainTest.BasicMatmul/* to allow it to run on Hopper
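The idea behind scheduler_utils::scheduleInputToSkipIntermediates can be illustrated with a toy sketch. This is a hypothetical, simplified model — the types and names below are not the actual nvFuser API: starting from a fusion input, walk through consecutive "metadata op" outputs (broadcast/squeeze/permute results), keep each one in global memory so it can be skipped via producer aliasing, and return the last global tensor, which the scheduler then records as the operand.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical toy IR, illustrating the idea behind
// scheduler_utils::scheduleInputToSkipIntermediates. Not the real nvFuser API.
enum class MemoryType { Global, Shared, Local };

struct Tensor {
  std::string name;
  MemoryType mtype = MemoryType::Local;
  bool is_metadata_op_output = false;  // e.g. broadcast/squeeze/permute result
  std::vector<Tensor*> consumers;
};

// Walk a consumer chain made purely of metadata ops, placing each such output
// in global memory (so it can be skipped with a producer alias instead of
// generating real data movement), and return the last global tensor reached.
// That tensor is what the scheduler later caches and uses as the operand role.
Tensor* scheduleInputToSkipIntermediates(Tensor* input) {
  input->mtype = MemoryType::Global;
  Tensor* last_gmem = input;
  // For simplicity this follows a single chain; the real utility works
  // recursively over all consumers of the input.
  Tensor* tv = input;
  while (tv->consumers.size() == 1 && tv->consumers[0]->is_metadata_op_output) {
    tv = tv->consumers[0];
    tv->mtype = MemoryType::Global;  // keep in gmem; no data movement needed
    last_gmem = tv;
  }
  return last_gmem;
}
```

For example, with a chain input -> broadcast -> permute -> MmaOp operand, the utility would mark the broadcast and permute outputs as global and return the permute output as the tensor to cache.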

TODO: Squeeze, permute, chain tests
Failing at computing TMA descriptor now
@jacobhinkle

!test

@jacobhinkle

jacobhinkle commented Mar 18, 2025

Currently some tests fail to compile because I am calling scheduler_utils::scheduleInputToSkipIntermediates even on Ampere. This should work, but it exposes some bugs in the current tensor producer alias system. A better system would not simply compute the index based on the skipped tensor, but would instead let us do src indexing on tensors that are not direct producers. I think TensorIndexer could support this, but I could not figure out how to do it without modifying TensorIndexer. So for now, I plan to skip this call in MultipleMatmulScheduler::cacheInputsAndOutputs upon request (i.e. on Ampere).

EDIT: see the bool skip_intermediates argument to cacheInputsAndOutputs.

@jacobhinkle

!test --diff

@jacobhinkle

!test --diff

@@ -467,19 +467,19 @@ void AmpereMultipleMatmulScheduler::validate() const {
}

void AmpereMultipleMatmulScheduler::run() {
// Clears memory spaces on intermediate tensors, calls
Changes to Ampere scheduler should not change generated code, but do let us use a common cacheInputsAndOutputs method.

Comment on lines +1098 to +1099
for (Val* dv : fusion_->outputs()) {
auto* d = dv->as<TensorView>();
dc is ignored anyway. We long ago stopped using cached_outputs_ in the Hopper scheduler, so I removed it for Ampere too as it was causing a problem due to the refactor not filling that vector.

@@ -1789,177 +1789,6 @@ std::string MatmulPattern::toString() const {

namespace {

// Check whether tv has all the output_groups in its logical domain, and
These utilities are no longer needed since we can now safely translate all matmul patterns the same way on both Hopper and Ampere. The differences are purely in downstream scheduling.

@@ -2492,7 +2492,7 @@ class MatmulSchedulerPluginTest : public NVFuserTest {

// Test that our fake plugin works to override the default heuristic
TEST_F(MatmulSchedulerPluginTest, BasicMatmul) {
-  NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(8, 0, 9, 0);
+  NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(8, 0, 10, 0);
More tests can likely be unguarded once we update their params or set a fixture to create default sets of parameters.

@jacobhinkle jacobhinkle changed the title [WIP] Translate MmaOp patterns properly on Hopper Translate MmaOp patterns properly on Hopper Mar 19, 2025
Fixes the horizontal fusion tests
This can be made to work but currently the config factory generates an
invalid config.
@jacobhinkle

!test --diff

@jacobhinkle

!test --diff

@jacobhinkle jacobhinkle marked this pull request as ready for review March 25, 2025 15:42
Comment on lines 129 to 145
 void HopperMultipleMatmulScheduler::run() {
-  // Clears memory spaces on intermediate tensors, calls
-  // cache{After,Before,Fork} on inputs and outputs
-  cacheInputsAndOutputs();
-
   // Finds matmul patterns and translates them to MmaOps, then finds tensor
   // and dimension roles for all tensors in the fusion
   findPatterns();
   translatePatterns();
-  findRoles();

+  // Clears memory spaces on intermediate tensors, calls
+  // cache{After,Before,Fork} on inputs and outputs.
-  // Defines acw_smem/bcw_smem and acr/bcr by possibly calling cacheAfter.
-  // This also collects mma_results_
-  defineOperandCaches();
+  cacheInputsAndOutputs(/*skip_intermediates=*/true);
+
+  // We wait until we are done caching tensors to find roles, since this
+  // requires building an IdModel, which would not be updated during the cache
+  // calls.
+  findRoles();

   inspectPrologues();
Rearranged to not cache until after translation of patterns. This is helpful because it lets us cache the global tensors that have been "skipped" with producer tensor aliases, instead of the original fusion inputs.
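The ordering constraint can be sketched with a toy model (hypothetical names, not the actual nvFuser code): translation may append propagated gmem "skipped" tensors between a fusion input and the MmaOp operand, so caching afterwards picks up the last of those rather than the raw input.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical sketch of why caching must follow pattern translation.
// A chain records the gmem tensors from a fusion input toward an MmaOp
// operand; translation appends propagated gmem intermediates to it.
struct Chain {
  std::vector<std::string> tensors;  // input -> ... -> last gmem tensor
};

// Simulates pattern translation adding a skipped gmem intermediate
// (e.g. a permuted view of the input, aliased to its producer).
void translatePattern(Chain& chain, const std::string& skipped_gmem) {
  chain.tensors.push_back(skipped_gmem);
}

// Caching picks the LAST gmem tensor in the chain: if run before
// translation, this would be the raw fusion input; run after, it is the
// propagated intermediate, which is what we want as the operand role.
std::string tensorToCache(const Chain& chain) {
  return chain.tensors.back();
}
```

Running tensorToCache before translatePattern would return the raw input, which is exactly the situation the reordering in run() avoids.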

@jacobhinkle jacobhinkle requested a review from rdspring1 March 25, 2025 15:55