Add silu and bias epilogue matmul tests #4095

Open: rdspring1 wants to merge 2 commits into main from matmul_epilogue_test_bcast
Conversation

@rdspring1 (Collaborator) commented on Mar 18, 2025

This PR adds the tests from #4069.

  • Inputs are already broadcast.
  • Uses an ideal input size of (m = 8192, n = 8192, k = 8192).
  • Created EpilogueBiasPersistentBroadcastInputs and EpilogueSiluPersistentBroadcastInputs tests with known-best MatmulParams.
  • Created FwdEpilogueBiasFusion.
  • Renamed FwdEpilogueFusion to FwdEpilogueSiluFusion.

github-actions bot commented Mar 18, 2025

Review updated until commit bc7af0f

Description

  • Added tests for bias and silu epilogue matmul

  • Created new test cases with specific MatmulParams

  • Renamed FwdEpilogueFusion to FwdEpilogueSiluFusion and added FwdEpilogueBiasFusion


Changes walkthrough 📝

Relevant files
Tests
tests/cpp/test_matmul.cpp: Add bias and SiLU epilogue matmul tests

  • Added FwdEpilogueBiasFusion test case
  • Added FwdEpilogueSiluFusion test case
  • Added EpilogueBiasPersistentBroadcastInputs test case
  • Added EpilogueSiluPersistentBroadcastInputs test case
  • +189/-1 

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review

    Tolerance Settings

    The tolerance settings for at::allclose are set to 5e-2 and 1e-1 in some tests, which might be too high for certain applications. It's important to ensure that these tolerances are justified and do not hide potential issues.

      EXPECT_TRUE(
          at::allclose(cg_outputs[0].as<at::Tensor>(), tv3_ref, 5e-2, 5e-2));
    }
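
    For context, the fixed tolerances can be related to the reduction size K the same way the other checks in this file already do. A minimal sketch of that heuristic (illustrative only, not part of the PR; `result` and `reference` are placeholder names):

      // With random-normal inputs, each output element sums K products, so its
      // magnitude grows roughly as sqrt(K), while per-term bf16 rounding error
      // is on the order of 2^-8. Scaling the tolerance with K, as the 1e-6 * K
      // checks below do, keeps the threshold proportional to the expected
      // accumulation error; fixed values like 5e-2 / 1e-1 are noticeably looser.
      constexpr int64_t K = 5120;   // reduction size used by the MLP tests
      const double tol = 1e-6 * K;  // ~5.1e-3 for K = 5120
      // EXPECT_TRUE(at::allclose(result, reference, /*rtol=*/tol, /*atol=*/tol));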
    Reference Calculation

    The reference calculation for EpilogueSiluPersistentBroadcastInputs uses at::linear followed by manual computation of the SiLU activation. It would be beneficial to verify that this manual computation matches the expected behavior of the SiLU activation function.

    auto tv3_ref = at::linear(t0.squeeze(), t1.squeeze());
    auto tv4_ref = tv3_ref.to(at::kFloat);
    auto tv11_ref =
        (tv4_ref * (1. / (1.0 + at::exp(-tv4_ref))) * t2).to(at::kBFloat16);
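
    One lightweight way to address this, sketched below (illustrative only, not part of the PR), is to cross-check the manual formula against ATen's built-in at::silu, which computes x * sigmoid(x):

      // Sketch: verify the hand-written SiLU reference against at::silu.
      // Assumes t0, t1, t2 from the test above are in scope.
      auto linear_out = at::linear(t0.squeeze(), t1.squeeze()).to(at::kFloat);
      auto manual_silu = linear_out * (1. / (1.0 + at::exp(-linear_out)));
      auto builtin_silu = at::silu(linear_out);
      // Both float32 formulations should agree to within rounding error.
      EXPECT_TRUE(at::allclose(manual_silu, builtin_silu, 1e-5, 1e-5));
      // The epilogue reference additionally multiplies by t2 and casts to bf16.
      auto silu_ref = (builtin_silu * t2).to(at::kBFloat16);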
    
    Performance Metrics

    The PR does not provide performance metrics or a comparison with existing implementations. It would be helpful to include performance data to demonstrate the effectiveness of the new tests and any performance improvements.
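
    As one possible way to collect such numbers, a minimal CUDA-event timing sketch is shown here (illustrative only, not part of the PR; assumes ke, inputs, and the M/N/K constants from the test below are in scope, plus <cuda_runtime.h> and <iostream>):

      cudaEvent_t start, stop;
      cudaEventCreate(&start);
      cudaEventCreate(&stop);

      ke.run(inputs); // warm-up so compilation/caching does not skew the timing

      constexpr int kIters = 10;
      cudaEventRecord(start);
      for (int i = 0; i < kIters; ++i) {
        ke.run(inputs);
      }
      cudaEventRecord(stop);
      cudaEventSynchronize(stop);

      float ms = 0.0f;
      cudaEventElapsedTime(&ms, start, stop);
      // 2*M*N*K FLOPs per matmul; report average latency and achieved TFLOP/s.
      const double tflops = 2.0 * M * N * K * kIters / (ms * 1e-3) / 1e12;
      std::cout << "avg " << (ms / kIters) << " ms, " << tflops << " TFLOP/s\n";

      cudaEventDestroy(start);
      cudaEventDestroy(stop);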

    TEST_P(MLPBenchmarkTest, FwdEpilogueBiasFusion) {
      Fusion fusion;
      FusionGuard fg(&fusion);
    
      constexpr int64_t M = 4096, N = 14336, K = 5120;
      const auto dtype = DataType::BFloat16;
    
      auto tv0 = makeContigConcreteTensor({-1, -1}, dtype); // M, K
      auto tv1 = makeContigConcreteTensor({-1, -1}, dtype); // N, K
      auto tv2 = makeContigConcreteTensor({-1, -1}, dtype); // M, N
      fusion.addInput(tv0);
      fusion.addInput(tv1);
      fusion.addInput(tv2);
    
      auto tv3 = linear(tv0, tv1, tv2);
      fusion.addOutput(tv3);
    
      auto options = at::TensorOptions().dtype(at::kBFloat16).device(at::kCUDA);
      auto t0 = at::randn({M, K}, options);
      auto t1 = at::randn({N, K}, options);
      auto t2 = at::randn({M, N}, options);
      auto tv3_ref = at::linear(t0, t1, t2);
    
      SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
          ->schedule(&fusion, &mparams);
    
      KernelArgumentHolder inputs = {t0, t1, t2};
    
      KernelExecutor ke;
      ke.compile(&fusion, inputs);
      EXPECT_TRUE(getBankConflictInfo(ke.compiledKernel()->kernel()).empty());
      auto cg_outputs = ke.run(inputs);
      ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
          ke.compiledKernel()->kernel()));
    
      // Relax tolerance for larger sum due to large K
      EXPECT_TRUE(
          at::allclose(cg_outputs[0].as<at::Tensor>(), tv3_ref, 5e-2, 5e-2));
    }
    
    TEST_P(MLPBenchmarkTest, FwdEpilogueSiluFusion) {
      Fusion fusion;
      FusionGuard fg(&fusion);
    
      constexpr int64_t M = 4096, N = 14336, K = 5120;
      const auto dtype = DataType::BFloat16;
    
      auto tv0 = makeContigConcreteTensor({-1, -1}, dtype); // M, K
      auto tv1 = makeContigConcreteTensor({-1, -1}, dtype); // N, K
      auto tv2 = makeContigConcreteTensor({-1, -1}, dtype); // M, N
      fusion.addInput(tv0);
      fusion.addInput(tv1);
      fusion.addInput(tv2);
    
      auto tv3 = linear(tv0, tv1);
      fusion.addOutput(tv3);
    
      auto tv4 = castOp(DataType::Float, tv3);
      auto tv5 = neg(tv4);
      auto tv6 = exp(tv5);
      auto tv7 = add(fusion.oneVal(DataType::Float), tv6);
      auto tv8 = reciprocal(tv7);
      auto tv9 = mul(tv4, tv8);
      auto tv10 = mul(tv9, tv2);
      auto tv11 = castOp(DataType::BFloat16, tv10);
      fusion.addOutput(tv11);
    
      auto options = at::TensorOptions().dtype(at::kBFloat16).device(at::kCUDA);
      auto t0 = at::randn({M, K}, options);
      auto t1 = at::randn({N, K}, options);
      auto t2 = at::randn({M, N}, options);
    
      auto tv3_ref = at::linear(t0, t1);
      auto tv4_ref = tv3_ref.to(at::kFloat);
      auto tv11_ref =
          (tv4_ref * (1. / (1.0 + at::exp(-tv4_ref))) * t2).to(at::kBFloat16);
    
      SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
          ->schedule(&fusion, &mparams);
    
      KernelArgumentHolder inputs = {t0, t1, t2};
    
      KernelExecutor ke;
      ke.compile(&fusion, inputs);
      EXPECT_TRUE(getBankConflictInfo(ke.compiledKernel()->kernel()).empty());
      auto cg_outputs = ke.run(inputs);
      ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
          ke.compiledKernel()->kernel()));
    
      // Relax tolerance for larger sum due to large K
      EXPECT_TRUE(at::allclose(
          cg_outputs[0].as<at::Tensor>(), tv3_ref, 1e-6 * K, 1e-6 * K));
      EXPECT_TRUE(
          at::allclose(cg_outputs[1].as<at::Tensor>(), tv11_ref, 1e-2, 1e-2));
    }
    
    TEST_P(MLPBenchmarkTest, FwdEpilogueFusion_BroadcastInputs) {
      GTEST_SKIP() << "THIS TEST IS CURRENTLY FAILING" << std::endl;
    
      Fusion fusion;
      FusionGuard fg(&fusion);
    
      constexpr int64_t M = 4096, N = 14336, K = 5120;
      const auto dtype = DataType::BFloat16;
    
      auto tv0 = makeContigConcreteTensor({-1, 1, -1}, dtype); // M, 1, K
      auto tv1 = makeContigConcreteTensor({1, -1, -1}, dtype); // 1, N, K
      auto tv2 = makeContigConcreteTensor({1, -1}, dtype); // M, N
      fusion.addInput(tv0);
      fusion.addInput(tv1);
      fusion.addInput(tv2);
    
      auto tv4 = fusedMultiplySum(tv0, tv1, {2});
      auto tv3 = castOp(dtype, tv4);
      fusion.addOutput(tv3);
    
      auto tv5 = neg(tv4);
      auto tv6 = exp(tv5);
      auto tv7 = add(fusion.oneVal(DataType::Float), tv6);
      auto tv8 = reciprocal(tv7);
      auto tv9 = mul(tv4, tv8);
      auto tv10 = mul(tv9, tv2);
      auto tv11 = castOp(dtype, tv10);
      fusion.addOutput(tv11);
    
      auto options = at::TensorOptions().dtype(at::kBFloat16).device(at::kCUDA);
      auto t0 = at::randn({M, 1, K}, options);
      auto t1 = at::randn({1, N, K}, options);
      auto t2 = at::randn({M, N}, options);
    
      auto tv3_ref = at::linear(t0.squeeze(), t1.squeeze());
      auto tv4_ref = tv3_ref.to(at::kFloat);
      auto tv11_ref =
          (tv4_ref * (1. / (1.0 + at::exp(-tv4_ref))) * t2).to(at::kBFloat16);
    
      std::vector<c10::IValue> inputs = {t0, t1, t2};
    
      SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
          ->schedule(&fusion, &mparams);
    
      KernelExecutor ke;
      ke.compile(&fusion, {t0, t1, t2});
      EXPECT_TRUE(getBankConflictInfo(ke.compiledKernel()->kernel()).empty());
      auto cg_outputs = ke.run({t0, t1, t2});
      ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
          ke.compiledKernel()->kernel()));
    
      // Relax tolerance for larger sum due to large K
      NVF_CHECK(at::allclose(
          cg_outputs[0].as<at::Tensor>(), tv3_ref, 1e-6 * K, 1e-6 * K));
      NVF_CHECK(at::allclose(cg_outputs[1].as<at::Tensor>(), tv11_ref, 1e-2, 1e-2));
    }
    
    TEST_P(MLPBenchmarkTest, FwdHorizontalFusion) {
      NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(9, 0, 10, 0);
    
      EnableOptionsGuard eog;
      EnableOptionsGuard::getCurOptions().set(EnableOption::FuseMultipleMatmuls);
    
      Fusion fusion;
      FusionGuard fg(&fusion);
    
      constexpr int64_t M = 4096, N = 14336, K = 5120;
      const auto dtype = DataType::BFloat16;
    
      auto tv0 = makeContigConcreteTensor({-1, -1}, dtype); // M, K
      auto tv1 = makeContigConcreteTensor({-1, -1}, dtype); // N, K
      auto tv2 = makeContigConcreteTensor({-1, -1}, dtype); // N, K
      fusion.addInput(tv0);
      fusion.addInput(tv1);
      fusion.addInput(tv2);
    
      auto tv3 = linear(tv0, tv1);
      fusion.addOutput(tv3);
    
      auto tv4 = castOp(DataType::Float, tv3);
      auto tv5 = neg(tv4);
      auto tv6 = exp(tv5);
      auto tv7 = add(fusion.oneVal(DataType::Float), tv6);
      auto tv8 = reciprocal(tv7);
      auto tv9 = mul(tv4, tv8);
    
      auto tv10 = linear(tv0, tv2);
      fusion.addOutput(tv10);
    
      auto tv11 = mul(tv9, tv10);
      auto tv12 = castOp(DataType::BFloat16, tv11);
      fusion.addOutput(tv12);
    
      auto options = at::TensorOptions().dtype(at::kBFloat16).device(at::kCUDA);
      auto t0 = at::randn({M, K}, options);
      auto t1 = at::randn({N, K}, options);
      auto t2 = at::randn({N, K}, options);
    
      auto tv3_ref = at::linear(t0, t1);
      auto tv4_ref = tv3_ref.to(at::kFloat);
      auto tv10_ref = at::linear(t0, t2);
      auto tv12_ref =
          (tv4_ref * (1. / (1.0 + at::exp(-tv4_ref))) * tv10_ref.to(at::kFloat))
              .to(at::kBFloat16);
    
      KernelArgumentHolder inputs = {t0, t1, t2};
    
      // Adjust parameters in order to fit smem and register constraints
      mparams.tile_sizes.cta_tile = GemmTile(128, 128, 64);
      mparams.tile_sizes.warp_tile = GemmTile(64, 128, 64);
      mparams.mma_macro = MmaMacro::Hopper_64_128_16;
      mparams.promote_prologue_smem_reuse = false;
      mparams.circular_buffer_options.smem_circular_buffer_stage = 2;
    
      SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
          ->schedule(&fusion, &mparams);
    
      KernelExecutor ke;
      ke.compile(&fusion, inputs);
      EXPECT_TRUE(getBankConflictInfo(ke.compiledKernel()->kernel()).empty());
      auto cg_outputs = ke.run(inputs);
      ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
          ke.compiledKernel()->kernel()));
    
      // Relax tolerance for larger sum due to large K
      // TODO: Some of these are failing, perhaps due to improper syncing of
      // horizontally fused kernels?
      EXPECT_TRUE(at::allclose(
          cg_outputs[0].as<at::Tensor>(), tv3_ref, 1e-6 * K, 1e-6 * K));
      EXPECT_TRUE(at::allclose(
          cg_outputs[1].as<at::Tensor>(), tv10_ref, 1e-6 * K, 1e-6 * K));
      EXPECT_TRUE(
          at::allclose(cg_outputs[2].as<at::Tensor>(), tv12_ref, 5e-2, 1e-1));
    }
    
    TEST_P(MLPBenchmarkTest, FwdHorizontalFusion_BroadcastInputs) {
      // TODO: This test currently fails on Ampere
      NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(9, 0, 10, 0);
    
      EnableOptionsGuard eog;
      EnableOptionsGuard::getCurOptions().set(EnableOption::FuseMultipleMatmuls);
    
      Fusion fusion;
      FusionGuard fg(&fusion);
    
      constexpr int64_t M = 4096, N = 14336, K = 5120;
      const auto dtype = DataType::BFloat16;
    
      auto tv0 = makeContigConcreteTensor({-1, 1, -1}, dtype); // M, 1, K
      auto tv1 = makeContigConcreteTensor({1, -1, -1}, dtype); // 1, N, K
      auto tv2 = makeContigConcreteTensor({1, -1, -1}, dtype); // 1, N, K
      fusion.addInput(tv0);
      fusion.addInput(tv1);
      fusion.addInput(tv2);
    
      auto tv4 = fusedMultiplySum(tv0, tv1, {2});
      auto tv3 = castOp(dtype, tv4);
      fusion.addOutput(tv3);
    
      auto tv5 = neg(tv4);
      auto tv6 = exp(tv5);
      auto tv7 = add(fusion.oneVal(DataType::Float), tv6);
      auto tv8 = reciprocal(tv7);
      auto tv9 = mul(tv4, tv8);
    
      auto tv10 = fusedMultiplySum(tv0, tv2, {2});
      auto tv10c = castOp(dtype, tv10);
      fusion.addOutput(tv10c);
    
      auto tv11 = mul(tv9, tv10);
      auto tv12 = castOp(DataType::BFloat16, tv11);
      fusion.addOutput(tv12);
    
      auto options = at::TensorOptions().dtype(at::kBFloat16).device(at::kCUDA);
      auto t0 = at::randn({M, 1, K}, options);
      auto t1 = at::randn({1, N, K}, options);
      auto t2 = at::randn({1, N, K}, options);
    
      auto tv3_ref = at::linear(t0.squeeze(), t1.squeeze());
      auto tv4_ref = tv3_ref.to(at::kFloat);
      auto tv10_ref = at::linear(t0.squeeze(), t2.squeeze());
      auto tv12_ref =
          (tv4_ref * (1. / (1.0 + at::exp(-tv4_ref))) * tv10_ref.to(at::kFloat))
              .to(at::kBFloat16);
    
      std::vector<c10::IValue> inputs{t0, t1, t2};
    
      // Adjust parameters in order to fit smem and register constraints
      mparams.tile_sizes.cta_tile = GemmTile(128, 128, 64);
      mparams.tile_sizes.warp_tile = GemmTile(64, 128, 64);
      mparams.mma_macro = MmaMacro::Hopper_64_128_16;
      mparams.promote_prologue_smem_reuse = false;
      mparams.circular_buffer_options.smem_circular_buffer_stage = 2;
    
      SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
          ->schedule(&fusion, &mparams);
    
      KernelExecutor ke;
      ke.compile(&fusion, {t0, t1, t2});
      EXPECT_TRUE(getBankConflictInfo(ke.compiledKernel()->kernel()).empty());
      auto cg_outputs = ke.run({t0, t1, t2});
      ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
          ke.compiledKernel()->kernel()));
    
      // Relax tolerance for larger sum due to large K
      NVF_CHECK(at::allclose(
          cg_outputs[0].as<at::Tensor>(), tv3_ref, 1e-6 * K, 1e-6 * K));
      NVF_CHECK(at::allclose(
          cg_outputs[1].as<at::Tensor>(), tv10_ref, 1e-6 * K, 1e-6 * K));
      NVF_CHECK(at::allclose(cg_outputs[2].as<at::Tensor>(), tv12_ref, 5e-2, 1e-1));
    }
    
    INSTANTIATE_TEST_SUITE_P(
        ,
        MLPBenchmarkTest,
        ::testing::Values(
            MLPBenchmarkTestParams{
                .warp_specialization = false,
                .persistent_kernel = false},
            MLPBenchmarkTestParams{
                .warp_specialization = true,
                .persistent_kernel = false},
            MLPBenchmarkTestParams{
                .warp_specialization = false,
                .persistent_kernel = true},
            MLPBenchmarkTestParams{
                .warp_specialization = true,
                .persistent_kernel = true}),
        [](const testing::TestParamInfo<MLPBenchmarkTestParams>& info) {
          std::stringstream ss;
          ss << (info.param.persistent_kernel ? "persistent" : "dataparallel");
          ss << (info.param.warp_specialization ? "_warpspec" : "_non_warpspec");
          return ss.str();
        });
    
    // This tests that we can use a small instruction tile with a medium size
    // warpgroup tile and a large CTA tile.
    TEST_F(HopperMatmulTest, HSH_NT_UseScheduler_MultipleInstructionsPerWarpTile) {
      Fusion fusion;
      FusionGuard fg(&fusion);
    
      constexpr int64_t M = 2048, N = 2048, K = 8192;
      const auto dtype = DataType::Half;
    
      auto tv0 = makeContigConcreteTensor({-1, -1, 1}, dtype); // K, M
      auto tv1 = makeContigConcreteTensor({-1, 1, -1}, dtype); // K, N
      fusion.addInput(tv0);
      fusion.addInput(tv1);
    
      auto tv2 = fusedMultiplySum(tv0, tv1, {0});
    
      // Reorder the accumulator as [M, N, K]
      // [K, M, N] -> [M, N, K]
      tv2->reorder({{-3, -1}});
      tv2->commitLeafToLogical();
    
      auto tv3 = castOp(DataType::Half, tv2);
      fusion.addOutput(tv3);
    
      auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
      auto t0 = at::randn({K, M, 1}, options);
      auto t1 = at::randn({K, 1, N}, options);
      auto out_ref = at::matmul(t0.squeeze().t(), t1.squeeze()).to(at::kHalf);
    
      MatMulTileOptions gemm_tile;
      // Regardless of the instruction, this should result in 2 warp groups i.e. 256
      // threads
      gemm_tile.cta_tile = GemmTile(256, 256, 32);
      gemm_tile.warp_tile = GemmTile(128, 128, 32);
    
      MatmulParams mparams;
      mparams.supported_vec_size = {8, 8, 8};
      mparams.mma_macro = MmaMacro::Hopper_64_64_16;
      mparams.tile_sizes = gemm_tile;
      mparams.cta_order = MatmulParams::TileRasterizationOrder::ColumnMajor;
      mparams.async_gmem_load_operands = true;
      mparams.circular_buffer_options.circular_buffer_smem_write = true;
      mparams.circular_buffer_options.circular_buffer_smem_read = false;
      mparams.circular_buffer_options.smem_circular_buffer_stage = 4;
      mparams.circular_buffer_options.smem_circular_buffer_prefetch_gap = 1;
      mparams.splitk_factor = 1;
      // NOTE: disabling smem use for this test since we currently hit a bank
      // conflict.
      // TODO: enable smem epilogue once stmatrix is updated
      mparams.use_smem_epilogue = false;
      mparams.cluster_dims = {2, 1, 1};
      mparams.promote_prologue_smem_reuse = false;
    
      SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
          ->schedule(&fusion, &mparams);
    
      KernelExecutor ke;
      ke.compile(&fusion, {t0, t1});
      kir::Kernel* kernel = ke.compiledKernel()->kernel();
      ASSERT_TRUE(kernel != nullptr);
      EXPECT_TRUE(getBankConflictInfo(kernel).empty());
      EXPECT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(kernel));
    
      auto cg_outputs = ke.run({t0, t1});
    
      // Check number of launched threads matches what we expect
      EXPECT_EQ(ke.lastLaunchParams().bdimx(), 128);
      EXPECT_EQ(ke.lastLaunchParams().bdimy(), 4)
          << " expected 4 warp groups (BIDy==4) but found BIDy=="
          << ke.lastLaunchParams().bdimy();
    
      // Relax tolerance for larger sum due to large K
      NVF_CHECK(at::allclose(
          cg_outputs[0].as<at::Tensor>(), out_ref, 1e-6 * K, 1e-6 * K));
    }
    
    TEST_F(HopperMatmulTest, ScheduleWithTranslation) {
      Fusion fusion;
      FusionGuard fg(&fusion);
    
      constexpr int64_t M = 2048, N = 2048, K = 8192;
      const auto dtype = DataType::Half;
    
      auto tv0 = makeContigConcreteTensor({-1, -1}, dtype); // M, K
      auto tv1 = makeContigConcreteTensor({-1, -1}, dtype); // K, N
      // Note tv1 has allocation domain
      // tv1->setAllocationDomain({tv1->axis(1), tv1->axis(0)}, true);
      fusion.addInput(tv0);
      fusion.addInput(tv1);
    
      auto tv2 = matmul(tv0, tv1);
    
      fusion.addOutput(tv2);
    
      auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
      auto t0 = at::randn({M, K}, options);
      // auto t1 = at::randn({N, K}, options).t();
      auto t1 = at::randn({K, N}, options);
      auto out_ref = at::matmul(t0, t1);
    
      MatMulTileOptions gemm_tile;
      gemm_tile.cta_tile = GemmTile(128, 256, 16);
      gemm_tile.warp_tile = GemmTile(64, 64, 16);
    
      MatmulParams mparams;
      mparams.supported_vec_size = {8, 8, 8};
      mparams.mma_macro = MmaMacro::Hopper_64_64_16;
      mparams.tile_sizes = gemm_tile;
      mparams.cta_order = MatmulParams::TileRasterizationOrder::RowMajor;
      mparams.async_gmem_load_operands = true;
      mparams.circular_buffer_options.circular_buffer_smem_write = true;
      mparams.circular_buffer_options.circular_buffer_smem_read = false;
      mparams.circular_buffer_options.smem_circular_buffer_stage = 3;
      mparams.circular_buffer_options.smem_circular_buffer_prefetch_gap = 1;
      mparams.splitk_factor = 1;
      mparams.use_smem_epilogue = true;
      mparams.cluster_dims = {1, 1, 1};
      mparams.promote_prologue_smem_reuse = true;
    
      SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
          ->schedule(&fusion, &mparams);
    
      KernelExecutor ke;
      ke.compile(&fusion, {t0, t1});
      kir::Kernel* kernel = ke.compiledKernel()->kernel();
      ASSERT_TRUE(kernel != nullptr);
      EXPECT_TRUE(getBankConflictInfo(kernel).empty());
      EXPECT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(kernel));
    
      auto cg_outputs = ke.run({t0, t1});
    
      // Relax tolerance for larger sum due to large K
      NVF_CHECK(at::allclose(
          cg_outputs[0].as<at::Tensor>(), out_ref, 1e-6 * K, 1e-6 * K));
    }
    
    // Test that we can compile matmul kernels with both 32-bit and 64-bit indexing,
    // and that if we pass arguments for which this is unsafe (meaning there is
    // overflow), that the appropriate exception is raised
    TEST_F(HopperMatmulTest, IndexTypeValidation) {
      Fusion fusion;
      FusionGuard fg(&fusion);
    
      const auto dtype = DataType::Half;
    
      auto tv0 = makeContigConcreteTensor({-1, -1, 1}, dtype); // M, K
      auto tv1 = makeContigConcreteTensor({1, -1, -1}, dtype); // K, N
      fusion.addInput(tv0);
      fusion.addInput(tv1);
    
      auto tv2 = fusedMultiplySum(tv0, tv1, {1});
    
      // Reorder the accumulator as [M, N, K]
      // [M, K, N] -> [M, N, K]
      tv2->reorder({{-2, -1}});
      tv2->commitLeafToLogical();
    
      auto tv3 = castOp(DataType::Half, tv2);
      fusion.addOutput(tv3);
    
      MatMulTileOptions gemm_tile;
      gemm_tile.cta_tile = GemmTile(128, 256, 64);
      gemm_tile.warp_tile = GemmTile(64, 256, 64);
    
      MatmulParams mparams;
      mparams.supported_vec_size = {8, 8, 8};
      mparams.mma_macro = MmaMacro::Hopper_64_256_16;
      mparams.tile_sizes = gemm_tile;
      mparams.cta_order = MatmulParams::TileRasterizationOrder::RowMajor;
      mparams.async_gmem_load_operands = true;
      mparams.circular_buffer_options.circular_buffer_smem_write = false;
      mparams.circular_buffer_options.circular_buffer_smem_read = false;
      mparams.circular_buffer_options.smem_circular_buffer_stage = 1;
      mparams.circular_buffer_options.smem_circular_buffer_prefetch_gap = 1;
      mparams.splitk_factor = 1;
      mparams.use_smem_epilogue = true;
      mparams.cluster_dims = {1, 1, 1};
      mparams.promote_prologue_smem_reuse = true;
    
      constexpr int64_t M = 1 << 17, N = 256, K = 1 << 17;
      auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
    
      { // This scope is to help us reclaim memory later
        auto a_ref = at::randn({M, K, 1}, options);
        auto b_ref = at::randn({1, K, N}, options);
        auto out_ref = at::matmul(a_ref.squeeze(), b_ref.squeeze()).to(at::kHalf);
        const std::vector<c10::IValue> inputs = {a_ref, b_ref};
    
        mparams.cparams.index_type = DataType::Int32;
    
        at::Tensor int32_output;
        {
          Fusion fusion_clone;
          Fusion::copy(&fusion, &fusion_clone);
          SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
              ->schedule(&fusion_clone, &mparams);
    
          KernelExecutor ke;
          ke.compile(&fusion_clone, inputs);
          int32_output = ke.run(inputs)[0].as<at::Tensor>();
        }
    
        mparams.cparams.index_type = DataType::Int;
    
        Fusion fusion_clone;
        Fusion::copy(&fusion, &fusion_clone);
        SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
            ->schedule(&fusion_clone, &mparams);
    
        KernelExecutor ke;
        ke.compile(&fusion_clone, inputs);
        auto int64_output = ke.run(inputs)[0].as<at::Tensor>();
        EXPECT_TRUE(int64_output.equal(int32_output));
      }
    
      // Test that passing inputs that are too large in one dimension lead to error
      maybeClearAllocator(/*max_bytes=*/0);
      {
        mparams.cparams.index_type = DataType::Int;
    
        Fusion fusion_clone;
        Fusion::copy(&fusion, &fusion_clone);
        SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
            ->schedule(&fusion_clone, &mparams);
    
        constexpr int64_t M_big = 1L << 32, N_big = 2, K_big = 2;
        auto a_big = at::randn({M_big, K_big, 1}, options);
        auto b_big = at::randn({1, K_big, N_big}, options);
        const std::vector<c10::IValue> inputs_big{a_big, b_big};
    
        KernelExecutor ke;
        ke.compile(&fusion_clone, inputs_big);
        EXPECT_THAT(
            [&]() { ke.run(inputs_big); },
            ::testing::ThrowsMessage<nvfuser::nvfError>(
                ::testing::HasSubstr("Found unsafe casts from DataType::Index")));
      }
    }
    
    TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle_BroadcastOp) {
      Fusion fusion;
      FusionGuard fg(&fusion);
    
      constexpr int64_t M = 2048, N = 2048, K = 8192;
      constexpr auto macro = MmaMacro::Hopper_64_256_16;
      constexpr auto layout = MmaLayout::NT; // [K, M] x [K, N] -> [M, N]
      constexpr auto swizzle = MmaInputSmemSwizzle::B128;
      const auto dtype = DataType::Half;
    
      constexpr bool use_smem_epilogue = false;
      constexpr bool use_warp_specialization = false;
    
      constexpr int64_t stages = 4;
      constexpr int64_t prefetch = 3;
      const int64_t cta_m = 2 * getM(macro);
      const int64_t cta_n = 1 * getN(macro);
    
      constexpr std::tuple<int64_t, int64_t, int64_t> cluster_dims{2, 1, 1};
    
      // auto tv0 = makeContigConcreteTensor({-1, -1, 1}, dtype);
      // auto tv1 = makeContigConcreteTensor({-1, 1, -1}, dtype);
      auto tv0 = makeContigConcreteTensor({-1, -1}, dtype);
      auto tv1 = makeContigConcreteTensor({-1, -1}, dtype);
      fusion.addInput(tv0);
      fusion.addInput(tv1);
    
      auto tv0b = broadcast(tv0, {false, false, true});
      auto tv1b = broadcast(tv1, {false, true, false});
    
      auto tv2 = fusedMultiplySum(tv0b, tv1b, {0});
    
      // Reorder the accumulator as [M, N, K]
      // [K, M, N] -> [M, N, K]
      tv2->reorder({{-3, -1}});
      tv2->commitLeafToLogical();
    
      auto tv3 = castOp(DataType::Half, tv2);
      fusion.addOutput(tv3);
    
      if constexpr (
          cluster_dims != std::tuple<int64_t, int64_t, int64_t>{1, 1, 1}) {
        fusion.manage("cluster_dims", cluster_dims);
      }
    
      auto mma_ops = ir_utils::getOpsOfType<MmaOp>(&fusion);
      NVF_CHECK(
          1 == mma_ops.size(),
          "Invalid number of MmaOp instances in fusion definition, expected 1, got ",
          mma_ops.size());
      mma_ops.front()->setMacro(macro);
    
      // gmem [K, M, 1] x gmem [K, 1, N] -mma-> register [M, N, rK]
      // register [M, N, rK] -cast-> gmem [M, N]
    
      auto tv0c = tv0b->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile);
      tv0c->setMemoryType(MemoryType::Shared);
      auto tv1c = tv1b->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile);
      tv1c->setMemoryType(MemoryType::Shared);
    
      tv0b->setMemoryType(MemoryType::Global);
      tv1b->setMemoryType(MemoryType::Global);
    
      TensorView *tv3c = nullptr, *tv3_shmem = nullptr;
      if (use_smem_epilogue) {
        tv3_shmem = tv3->cacheBefore();
        tv3c = tv3_shmem->cacheBefore();
        tv3_shmem->setMemoryType(MemoryType::Shared);
        tv3c->setMemoryType(MemoryType::Local);
        tv3_shmem->definition()->as<LoadStoreOp>()->setOpType(
            LoadStoreOpType::StMatrix);
        tv3->definition()->as<LoadStoreOp>()->setOpType(
            LoadStoreOpType::CpAsyncBulkTensorTile);
      } else {
        tv3c = tv3->cacheBefore();
        tv3c->setMemoryType(MemoryType::Local);
      }
    
      // gmem [K, M, 1] -TMA-> smem [K, M, 1]
      // gmem [K, 1, N] -TMA-> smem [K, 1, N]
      // smem [K, M, 1] x smem [K, 1, N] -mma-> register [M, N, rK]
      // register [M, N, rK] -cast-> register [M, N] -set-> gmem [M, N]
    
      // Create tiles
      tv2->split(-3, cta_m);
      tv2->split(-2, cta_n);
      tv2->split(-1, getK(macro));
      // [Mo, Mi, No, Ni, Ko, Ki] -> [Mo, No, Ko, Mi, Ni, Ki]
      tv2->reorder({{-5, -3}, {-3, -2}});
      tv2->axis(0)->parallelize(ParallelType::BIDy);
      tv2->axis(1)->parallelize(ParallelType::BIDx);
    
      TransformPropagator propagator(tv2);
      MaxLogicalDomainInfoSpanningTree(tv2).traverse(&propagator);
      scheduler_utils::parallelizeAllLike(tv2);
    
      // [..., Mi, Ki] -> [..., Ki, Mi]
      tv0c->reorder({{-3, -1}});
      tv0c->applyMmaSwizzleForTMALoad(swizzle);
      tv0c->axis(-6)->parallelize(ParallelType::Unroll);
      // [..., Ni, Ki] -> [..., Ki, Ni]
      tv1c->reorder({{-1, -2}});
      tv1c->applyMmaSwizzleForTMALoad(swizzle);
      tv1c->axis(-6)->parallelize(ParallelType::Unroll);
    
      // Strip ParallelType::Bulk from the broadcast tensors, since its definition
      // is not a TMA
      for (TensorView* tv : {tv0b, tv1b}) {
        for (IterDomain* id : tv->getLoopDomain()) {
          if (id->isBulk()) {
            id->parallelize(ParallelType::Serial);
          }
        }
      }
    
      {
        tv2->split(-3, getM(macro));
        tv2->split(-2, getN(macro));
        // [Mo, No, Ko, Mio, Mii, Nio, Nii, Ki]
        // -> [Mo, No, Ko, Mio, Nio, Mii, Nii, Ki]
        tv2->reorder({{-4, -3}});
        tv2->merge(-5);
        tv2->axis(-4)->parallelize(ParallelType::TIDy);
        scheduler_utils::BoundedDirectionalTransformPropagator::forward(
            tv2,
            -1,
            {tv3},
            scheduler_utils::BoundedDirectionalTransformPropagator::Options()
                .propagateParallelType()
                .propagateToBoundary());
      }
    
      {
        auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
            tv2->getLoopDomain());
        tv2->setAllocationDomain(s.as<IterDomain*>(), true);
        tv2->axis(-1)->parallelize(ParallelType::Mma);
        tv2->axis(-2)->parallelize(ParallelType::Mma);
        tv2->axis(-3)->parallelize(ParallelType::Mma);
      }
    
      if (!use_smem_epilogue) {
        for (auto tv : {tv3c, tv3}) {
          auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
              tv->getLoopDomain());
          tv->setLoopDomain(s.as<IterDomain*>());
        }
        tv3->axis(-1)->parallelize(ParallelType::Vectorize);
      } else {
        auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
            tv3c->getLoopDomain());
        tv3c->setLoopDomain(s.as<IterDomain*>());
        tv3c->setAllocationDomain(s.as<IterDomain*>(), true);
    
        constexpr int64_t stmatrix_tile_m = 16;
        constexpr int64_t stmatrix_tile_n = 16;
        fusion.manage("st_matrix_m_tile", stmatrix_tile_m);
        fusion.manage("st_matrix_n_tile", stmatrix_tile_n);
        fusion.manage("st_matrix_m", getM(macro));
        fusion.manage("st_matrix_n", getN(macro));
    
        MmaInputSmemSwizzle store_swizzle =
            mma_utils::tmaSwizzleSharedMemory(tv3_shmem);
    
        // This internally calls
        // Schedule shared memory cache; Output from StMatrix
        mma_utils::scheduleStMatrixForMmaOutput(
            tv3_shmem, stmatrix_tile_m, stmatrix_tile_n);
    
        // Schedule global memory output; Output from TMA Store
        mma_utils::scheduleTMAStoreForMmaOutput(tv3, store_swizzle);
      }
    
      inlineMost(ir_utils::allTvsExcept(&fusion, {tv0, tv1, tv0b, tv1b}));
    
      if (use_warp_specialization) {
        tv0c->circularBuffer(stages, prefetch, WarpSpecialized(ParallelType::TIDy));
        tv1c->circularBuffer(stages, prefetch, WarpSpecialized(ParallelType::TIDy));
      } else {
        tv0c->circularBuffer(stages, prefetch);
        tv1c->circularBuffer(stages, prefetch);
      }
    
      auto inputs =
          matmulAtInput3DHopperSS(M, N, K, layout, data_type_to_aten(dtype));
      inputs.first = inputs.first.squeeze();
      inputs.second = inputs.second.squeeze();
    
      KernelExecutor ke;
      ke.compile(
          &fusion, {inputs.first, inputs.second}, LaunchParams(), matmul_cparams);
      auto cg_outputs = ke.run({inputs.first, inputs.second});
      auto tref = atMatmul(inputs.first.squeeze(), inputs.second.squeeze(), layout);
      EXPECT_TRUE(at::allclose(cg_outputs[0].as<at::Tensor>(), tref, 1e-5, 1e-5));
    
      // The following check fails if the BroadcastOps are not removed at lowering
      // time, resulting in two intermediate global allocations.
      const kir::KernelSummary& summary = ke.compiledKernel()->kernel()->summary();
      EXPECT_EQ(summary.global_allocations.size(), 0)
          << "Expected to have no intermediate global allocations";
    }
    
    // See https://github.com/NVIDIA/Fuser/issues/3962
    TEST_F(HopperMatmulTest, MLPGemmPersistentBroadcastInputs) {
      EnableOptionsGuard eog;
      EnableOptionsGuard::getCurOptions().set(EnableOption::FuseMultipleMatmuls);
    
      Fusion fusion;
      FusionGuard fg(&fusion);
    
      constexpr int64_t M = 8192, N = 8192, K = 8192;
      const auto dtype = DataType::BFloat16;
    
      auto tv0 = makeContigConcreteTensor({-1, 1, -1}, dtype); // M, 1, K
      auto tv1 = makeContigConcreteTensor({1, -1, -1}, dtype); // 1, N, K
      auto tv2 = makeContigConcreteTensor({1, -1, -1}, dtype); // 1, N, K
      fusion.addInput(tv0);
      fusion.addInput(tv1);
      fusion.addInput(tv2);
    
      auto tv4 = fusedMultiplySum(tv0, tv1, {2});
      auto tv3 = castOp(dtype, tv4);
      fusion.addOutput(tv3);
    
      auto options = at::TensorOptions().dtype(at::kBFloat16).device(at::kCUDA);
      auto a_ref = at::randn({M, 1, K}, options);
      auto b_ref = at::randn({1, N, K}, options);
      auto c_ref = at::randn({1, N, K}, options);
      clearL2Cache();
    
      auto tv3_ref = at::linear(a_ref.squeeze(), b_ref.squeeze());
      clearL2Cache();
    
      MatMulTileOptions gemm_tile;
      gemm_tile.cta_tile = GemmTile(128, 256, 64);
      gemm_tile.warp_tile = GemmTile(64, 256, 64);
    
      MatmulParams mparams;
      mparams.supported_vec_size = {8, 8, 8};
      mparams.mma_macro = MmaMacro::Hopper_64_256_16;
      mparams.tile_sizes = gemm_tile;
      mparams.cta_order = MatmulParams::TileRasterizationOrder::RowMajor;
      mparams.async_gmem_load_operands = true;
      mparams.circular_buffering_strategy =
          MatmulParams::CircularBufferingStrategy::WarpSpecialized;
      mparams.tiling_strategy =
          MatmulParams::TilingStrategy::DistributeTilesAcrossSMs;
      mparams.circular_buffer_options.circular_buffer_smem_write = true;
      mparams.circular_buffer_options.circular_buffer_smem_read = false;
      mparams.grid_swizzle_factor = 8;
      // TODO: reduced shared memory aliasing because of persistent scheduling
      mparams.circular_buffer_options.smem_circular_buffer_stage = 3;
      mparams.circular_buffer_options.smem_circular_buffer_prefetch_gap = 1;
      mparams.splitk_factor = 1;
      mparams.use_smem_epilogue = true;
      mparams.cluster_dims = {2, 1, 1};
      mparams.promote_prologue_smem_reuse = true;
    
      SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
          ->schedule(&fusion, &mparams);
      std::vector<c10::IValue> inputs = {a_ref, b_ref, c_ref};
    
      KernelExecutor ke;
      ke.compile(&fusion, inputs);
      EXPECT_TRUE(getBankConflictInfo(ke.compiledKernel()->kernel()).empty());
      auto cg_outputs = ke.run(inputs);
      ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
          ke.compiledKernel()->kernel()));
    
      // Relax tolerance for larger sum due to large K
      EXPECT_TRUE(
          cg_outputs[0].as<at::Tensor>().allclose(tv3_ref, 1e-6 * K, 1e-6 * K));
    }
    
    TEST_F(HopperMatmulTest, EpilogueBiasPersistentBroadcastInputs) {
      EnableOptionsGuard eog;
      EnableOptionsGuard::getCurOptions().set(EnableOption::FuseMultipleMatmuls);
    
      Fusion fusion;
      FusionGuard fg(&fusion);
    
      constexpr int64_t M = 8192, N = 8192, K = 8192;
      const auto dtype = DataType::BFloat16;
    
      auto tv0 = makeContigConcreteTensor({-1, 1, -1}, dtype); // M, 1, K
      auto tv1 = makeContigConcreteTensor({1, -1, -1}, dtype); // 1, N, K
      auto tv2 = makeContigConcreteTensor({-1, -1}, dtype); // M, N
      fusion.addInput(tv0);
      fusion.addInput(tv1);
      fusion.addInput(tv2);
    
      auto tv3 = fusedMultiplySum(tv0, tv1, {2});
      auto tv4 = add(tv3, tv2);
      auto tv5 = castOp(DataType::BFloat16, tv4);
      fusion.addOutput(tv5);
    
      auto options = at::TensorOptions().dtype(at::kBFloat16).device(at::kCUDA);
      auto t0 = at::randn({M, 1, K}, options);
      auto t1 = at::randn({1, N, K}, options);
      auto t2 = at::randn({M, N}, options);
      auto tv3_ref = at::linear(t0.squeeze(), t1.squeeze(), t2);
    
      std::vector<c10::IValue> inputs = {t0, t1, t2};
    
      MatMulTileOptions gemm_tile;
      gemm_tile.cta_tile = GemmTile(128, 256, 64);
      gemm_tile.warp_tile = GemmTile(64, 256, 64);
    
      MatmulParams mparams;
      mparams.supported_vec_size = {8, 8, 8};
      mparams.mma_macro = MmaMacro::Hopper_64_256_16;
      mparams.tile_sizes = gemm_tile;
      mparams.cta_order = MatmulParams::TileRasterizationOrder::RowMajor;
      mparams.async_gmem_load_operands = true;
      mparams.circular_buffering_strategy =
          MatmulParams::CircularBufferingStrategy::WarpSpecialized;
      mparams.tiling_strategy =
          MatmulParams::TilingStrategy::DistributeTilesAcrossSMs;
      mparams.circular_buffer_options.circular_buffer_smem_write = true;
      mparams.circular_buffer_options.circular_buffer_smem_read = false;
      // TODO: reduced shared memory aliasing because of persistent scheduling
      mparams.circular_buffer_options.smem_circular_buffer_stage = 3;
      mparams.circular_buffer_options.smem_circular_buffer_prefetch_gap = 1;
      mparams.splitk_factor = 1;
      mparams.use_smem_epilogue = true;
      mparams.cluster_dims = {2, 1, 1};
      mparams.promote_prologue_smem_reuse = true;
    
      SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
          ->schedule(&fusion, &mparams);
    
      KernelExecutor ke;
      ke.compile(&fusion, inputs);
      EXPECT_TRUE(getBankConflictInfo(ke.compiledKernel()->kernel()).empty());
      auto cg_outputs = ke.run(inputs);
      ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
          ke.compiledKernel()->kernel()));
    
      // Relax tolerance for larger sum due to large K
      EXPECT_TRUE(
          at::allclose(cg_outputs[0].as<at::Tensor>(), tv3_ref, 5e-2, 5e-2));
    }
    
    TEST_F(HopperMatmulTest, EpilogueSiluPersistentBroadcastInputs) {
      EnableOptionsGuard eog;
      EnableOptionsGuard::getCurOptions().set(EnableOption::FuseMultipleMatmuls);
    
      Fusion fusion;
      FusionGuard fg(&fusion);
    
      constexpr int64_t M = 8192, N = 8192, K = 8192;
      const auto dtype = DataType::BFloat16;
    
      auto tv0 = makeContigConcreteTensor({-1, 1, -1}, dtype); // M, 1, K
      auto tv1 = makeContigConcreteTensor({1, -1, -1}, dtype); // 1, N, K
      auto tv2 = makeContigConcreteTensor({-1, -1}, dtype); // M, N
      fusion.addInput(tv0);
      fusion.addInput(tv1);
      fusion.addInput(tv2);
    
      auto tv3 = fusedMultiplySum(tv0, tv1, {2});
      auto tv4 = castOp(DataType::Float, tv3);
      auto tv5 = neg(tv4);
      auto tv6 = exp(tv5);
      auto tv7 = add(fusion.oneVal(DataType::Float), tv6);
      auto tv8 = reciprocal(tv7);
      auto tv9 = mul(tv4, tv8);
      auto tv10 = mul(tv9, tv2);
      auto tv11 = castOp(DataType::BFloat16, tv10);
      fusion.addOutput(tv11);
    
      auto options = at::TensorOptions().dtype(at::kBFloat16).device(at::kCUDA);
      auto t0 = at::randn({M, 1, K}, options);
      auto t1 = at::randn({1, N, K}, options);
      auto t2 = at::randn({M, N}, options);
    
      auto tv3_ref = at::linear(t0.squeeze(), t1.squeeze());
      auto tv4_ref = tv3_ref.to(at::kFloat);
      auto tv11_ref =
          (tv4_ref * (1. / (1.0 + at::exp(-tv4_ref))) * t2).to(at::kBFloat16);
    
      std::vector<c10::IValue> inputs = {t0, t1, t2};
    
      MatMulTileOptions gemm_tile;
      gemm_tile.cta_tile = GemmTile(128, 256, 64);
      gemm_tile.warp_tile = GemmTile(64, 256, 64);
    
      MatmulParams mparams;
      mparams.supported_vec_size = {8, 8, 8};
      mparams.mma_macro = MmaMacro::Hopper_64_256_16;
      mparams.tile_sizes = gemm_tile;
      mparams.cta_order = MatmulParams::TileRasterizationOrder::RowMajor;
      mparams.async_gmem_load_operands = true;
      mparams.circular_buffering_strategy =
          MatmulParams::CircularBufferingStrategy::WarpSpecialized;
      mparams.tiling_strategy =
          MatmulParams::TilingStrategy::DistributeTilesAcrossSMs;
      mparams.circular_buffer_options.circular_buffer_smem_write = true;
      mparams.circular_buffer_options.circular_buffer_smem_read = false;
      // TODO: reduced shared memory aliasing because of persistent scheduling
      mparams.circular_buffer_options.smem_circular_buffer_stage = 3;
      mparams.circular_buffer_options.smem_circular_buffer_prefetch_gap = 1;
      mparams.splitk_factor = 1;
      mparams.use_smem_epilogue = true;
      mparams.cluster_dims = {2, 1, 1};
      mparams.promote_prologue_smem_reuse = true;
    
      SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
          ->schedule(&fusion, &mparams);
    
      KernelExecutor ke;
      ke.compile(&fusion, inputs);
      EXPECT_TRUE(getBankConflictInfo(ke.compiledKernel()->kernel()).empty());
      auto cg_outputs = ke.run(inputs);
      ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(
          ke.compiledKernel()->kernel()));
    
      // Relax tolerance for larger sum due to large K
      EXPECT_TRUE(
          at::allclose(cg_outputs[0].as<at::Tensor>(), tv11_ref, 5e-2, 1e-1));
    }

    @jacobhinkle (Collaborator) left a comment:


    Should we parametrize these like we do with the MLPBenchmark tests?

    rdspring1 force-pushed the matmul_epilogue_test_bcast branch from b132f90 to 3049892 on March 19, 2025, 16:03
    @rdspring1 (Collaborator, Author) replied to @jacobhinkle:

    > Should we parametrize these like we do with the MLPBenchmark tests?

    EpilogueBiasPersistentBroadcastInputs and EpilogueSiluPersistentBroadcastInputs have MatmulParams that should match nvjet, except for grid_swizzle. The number of shared-memory stages cannot be shared between the persistent and data-parallel schedules.

    I added FwdEpilogueBiasFusion for parametrized testing and renamed FwdEpilogueFusion to FwdEpilogueSiluFusion.
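
    For readers comparing the configurations, the persistent-specific MatmulParams fields (condensed from the diff above, not new code) are roughly:

      // From Epilogue{Bias,Silu}PersistentBroadcastInputs in the diff above.
      mparams.tiling_strategy =
          MatmulParams::TilingStrategy::DistributeTilesAcrossSMs;
      mparams.circular_buffering_strategy =
          MatmulParams::CircularBufferingStrategy::WarpSpecialized;
      // Persistent scheduling reduces shared-memory aliasing, so the stage
      // count is pinned at 3 rather than reusing the parametrized tests' value.
      mparams.circular_buffer_options.smem_circular_buffer_stage = 3;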

    Commits:
      * Inputs are already broadcasted
      * Created EpilogueBiasPersistentBroadcastInputs AND EpilogueSiluPersistentBroadcastInputs tests
      * Create FwdEpilogueBiasFusion
      * Rename FwdEpilogueFusion to FwdEpilogueBiasFusion
    rdspring1 force-pushed the matmul_epilogue_test_bcast branch from 3049892 to bc7af0f on March 20, 2025, 16:33
    @rdspring1 (Collaborator, Author) commented:

    !build
