fix warp specialized tma for ln bwd #4663
Merged: 9 commits merged into main from llu/ws_tma_lnopt on Jun 27, 2025

Conversation

@liqiangxl (Collaborator) commented Jun 24, 2025

Summary

This PR addresses a remaining inline issue for layer norm backward operations generated by Thunder, building upon the previous fix in #4561.

Issue

The getGroupedReductionPersistentTvs() function was designed to identify persistent tensors for grouped reduction operations. However, it only examined cached input tensors that already had inner broadcast dimensions (e.g., input [Iter,1] --> tv1 [Iter,1], T4 --> T73 in the following figure). This approach missed tensors that initially lack broadcast dimensions but are later consumed by broadcast operations (e.g., input [Iter] --> tv1 [Iter] --broadcast()--> tv2[Iter,1]). This limitation specifically affects layer norm backward operations generated by Thunder, as illustrated by tensors T70 and T69 in the provided figure.
[Figure: fusion graph. Red colored reduction path: T70 --> T85. Blue colored post-reduction path: T70 --> T44.]
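
To make the two patterns concrete, here is a minimal sketch using the same nvFuser C++ test API that appears in the test later in this PR (shapes and tensor names are illustrative only, not taken from the actual fusion):

auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());

// Pattern the old logic handled: the input already carries an inner
// broadcast dimension, i.e. it enters the fusion as [Iter, 1]
// (like T4 --> T73 in the figure).
auto tv_with_bcast = makeContigConcreteTensor({16384, 1}, DataType::Float);
fusion->addInput(tv_with_bcast);

// Pattern it missed: a 1-D input [Iter] whose broadcast dimension is only
// added later by an explicit broadcast() (like T70 and T69 in the figure).
auto tv_no_bcast = makeContigConcreteTensor({16384}, DataType::Float);
fusion->addInput(tv_no_bcast);
auto tv_bcast_later = broadcast(tv_no_bcast, {false, true}); // [Iter] --> [Iter, 1]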

Fix

Modified the loop logic to iterate over all cached tensors, skipping only those that are TMA-loaded or already vectorized (a rough sketch follows).
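
A hypothetical sketch of the revised selection (isTmaLoaded() and collectPersistentCandidates() are illustrative names, not the actual nvFuser API; the vectorization check mirrors the code quoted in the review comments below):

// Sketch only: iterate over every cached tensor instead of only those that
// already have an inner broadcast dimension.
std::vector<TensorView*> collectPersistentCandidates(
    const std::vector<TensorView*>& cached_inputs) {
  std::vector<TensorView*> candidates;
  for (TensorView* cached_tv : cached_inputs) {
    // Skip tensors loaded through TMA; they are already staged in shared
    // memory. (isTmaLoaded is a hypothetical helper for illustration.)
    if (isTmaLoaded(cached_tv)) {
      continue;
    }
    // Skip tensors already vectorized by the general vectorization analysis
    // and propagation.
    if (cached_tv->axis(-1)->getParallelType() == ParallelType::Vectorize) {
      continue;
    }
    // Everything else is a candidate, whether or not it already carries an
    // inner broadcast dimension.
    candidates.push_back(cached_tv);
  }
  return candidates;
}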

Additional Changes

Slightly adjusted heuristics to enhance performance for layer norm backward operations when ping-pong buffering is not utilized.

Tests

The fusion definition generated by Thunder is converted to a C++ test.

Performance

(1) Compare layer norm bwd with the multi-wave approach. [performance chart]
(2) Compare RMS norm bwd with the current main branch using the user-enforced warp-specialized approach. [performance chart]

@liqiangxl (Collaborator, Author)

!test


github-actions bot commented Jun 24, 2025

Review updated until commit d33dc5c

Description

  • Enhanced getGroupedReductionPersistentTvs to consider all cached tensors.

  • Improved heuristics for bdimx increase in getHeuristics.

  • Added a test for layer norm backward generated by Thunder.


Changes walkthrough 📝

Relevant files

Enhancement

  csrc/scheduler/normalization_inner_outer_tma_ws.cpp: Improve bdimx increase and tensor handling (+38/-29)
  • Enhanced logic for increasing bdimx in getHeuristics.
  • Updated tensor handling in scheduleFusion for grouped reductions.

  csrc/scheduler/normalization_inner_outer_utils.cpp: Update grouped reduction tensor selection (+5/-7)
  • Modified getGroupedReductionPersistentTvs to consider all cached tensors.

Tests

  tests/cpp/test_combined_inner_outer_reduction.cpp: Add layer norm backward test (+157/-0)
  • Added a test for layer norm backward operations generated by Thunder.

Documentation

  csrc/scheduler/normalization_inner_outer_utils.h: Update function parameter names (+2/-2)
  • Updated function parameter names in getGroupedReductionPersistentTvs.

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review

    Performance Heuristics

    The heuristics for increasing bdimx and iter_unroll should be validated to ensure they do not lead to performance regressions or excessive resource usage.

        is_enough_smem(iter_unroll * 2, n_stages, bdimx, bdimy) &&
        outer_dim_numel % (iter_unroll * 2) == 0) {
      is_updated = true;
      iter_unroll *= 2;
    }
    
    // consider increasing bdimx only when pingpong is not used.
    if (bdimy == 1) {
      int64_t new_bdimx = bdimx + 128;
      // ensure new bdimx is within bounds and smem is enough
      bool can_increase = (new_bdimx <= max_bdimx) &&
          is_enough_smem(iter_unroll, n_stages, new_bdimx, bdimy);
      auto get_tail = [](int64_t a, int64_t b) {
        return a % b == 0 ? b : a % b;
      };
      // try to increase bdimx only when:
      // (1) Beneficial from more register usage. When bdimx is 128, only 128 x
      //     256 registers are used; should increase to use all 64K registers.
      // (2) Beneficial from divisible split.
      // (3) Current bdimx leads to register spills.
      bool try_increase = (bdimx == 128) ||
          (get_tail(after_vect, new_bdimx) >= get_tail(after_vect, bdimx)) ||
          (!is_enough_regs(iter_unroll, bdimx, bdimy));
      if (can_increase && try_increase) {
        is_updated = true;
        bdimx += 128;
      }
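
    For reference, the get_tail lambda above returns the size of the last chunk when a is split by b (a full b means a divisible split); condition (2) holds when the tail under the new bdimx is at least as large as the tail under the current bdimx. A small worked example (illustrative numbers only):

    // get_tail(1024, 256) == 256  -> divisible split
    // get_tail(1000, 256) == 232  -> last chunk is partially filled
    // With after_vect = 1000, bdimx = 128, new_bdimx = 256:
    //   get_tail(1000, 256) == 232 >= get_tail(1000, 128) == 104,
    //   so the divisible-split condition (2) would permit the increase.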
    Tensor Skipping Logic

    The logic for skipping tensors that are TMA-loaded or already vectorized should be reviewed to ensure it correctly identifies and handles all relevant cases.

    std::vector<TensorView*> getGroupedReductionPersistentTvs(
        Fusion* fusion,
        TensorView* cached_input,
        const std::vector<TensorView*>& reduction_tvs) {
      std::vector<TensorView*> res;
      // Get all fusion outputs that are consumers of reduction tvs
      const auto& reduction_to_output = DependencyCheck::getAllOutputsOf(
          {reduction_tvs.begin(), reduction_tvs.end()});
      std::unordered_set<TensorView*> p_of_reductions;
      std::unordered_set<TensorView*> c_of_reductions;
      for (auto output : reduction_to_output) {
        auto chains_to_output =
            DependencyCheck::getAllDependencyChains(cached_input, output);
        for (auto chain : chains_to_output) {
    Test Validation

    The test ThunderLayerNormBackward should be further validated to ensure the fusion runs correctly and the expected outputs match the actual outputs, especially considering the type conversion issues mentioned in the TODO.

    TEST_F(CombinedSchedulerTest, ThunderLayerNormBackward) {
      NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0);
      EnableOptionsGuard opt_guard;
      EnableOptionsGuard::getCurOptions().set(
          EnableOption::WarpSpecializedNormalization);
      std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
      auto fusion = fusion_ptr.get();
      FusionGuard fg(fusion);
    
      // Define constants for dimensions
      const int64_t dim0 = 16384;
      const int64_t dim1 = 2048;
      auto dim0_val = IrBuilder::create<Val>(dim0);
      auto dim1_val = IrBuilder::create<Val>(dim1);
      auto one_val = IrBuilder::create<Val>(1);
    
      {
        auto tv0 = makeContigConcreteTensor({dim1}, DataType::BFloat16);
        fusion->addInput(tv0);
        auto tv1 = makeContigConcreteTensor({dim0}, DataType::Float);
        fusion->addInput(tv1);
        auto tv2 = makeContigConcreteTensor({dim0, dim1}, DataType::BFloat16);
        fusion->addInput(tv2);
        auto tv3 = makeContigConcreteTensor({dim0, dim1}, DataType::BFloat16);
        fusion->addInput(tv3);
        auto tv4 = makeContigConcreteTensor({dim0, 1}, DataType::Float);
        fusion->addInput(tv4);
        auto tv8 = expand(broadcast(tv0, {true, false}), {dim0_val, dim1_val});
        auto tv12 = expand(broadcast(tv1, {false, true}), {dim0_val, one_val});
        auto tv13 = castOp(DataType::Float, tv2);
        auto tv14 = castOp(DataType::Float, tv8);
        auto tv18 = expand(broadcast(tv12, {false, false}), {dim0_val, dim1_val});
        auto tv19 = castOp(DataType::Float, tv3);
        auto tv20 = mul(tv14, tv13);
        auto tv21 = sub(tv19, tv18);
        auto tv22 = mul(tv21, tv20);
        auto tv23 = sum(tv22, {1}, false);
        auto tv27 = expand(broadcast(tv23, {false, true}), {dim0_val, one_val});
        auto tv31 = expand(broadcast(tv4, {false, false}), {dim0_val, dim1_val});
        auto s32 = IrBuilder::create<Val>(3.0, DataType::Double);
        auto tv33 = pow(tv4, s32);
        auto s34 = IrBuilder::create<Val>(-0.5, DataType::Double);
        auto tv35 = mul(s34, tv27);
        auto tv36 = mul(tv31, tv20);
        auto tv37 = mul(tv35, tv33);
        auto tv38 = neg(tv36);
        auto tv39 = sum(tv37, {1}, false);
        auto tv40 = sum(tv38, {1}, false);
        auto tv44 = expand(broadcast(tv1, {false, true}), {dim0_val, one_val});
        auto tv48 = expand(broadcast(tv39, {false, true}), {dim0_val, one_val});
        auto tv52 = expand(broadcast(tv40, {false, true}), {dim0_val, one_val});
        auto tv56 = expand(broadcast(tv44, {false, false}), {dim0_val, dim1_val});
        auto tv60 = expand(broadcast(tv48, {false, false}), {dim0_val, dim1_val});
        auto tv61 = sum(tv52, {1}, false);
        auto tv62 = sub(tv19, tv56);
        auto s63 = IrBuilder::create<Val>(2.0, DataType::Double);
        auto tv64 = mul(s63, tv60);
        auto tv68 = expand(broadcast(tv61, {false, true}), {dim0_val, one_val});
        auto tv69 = mul(tv64, tv62);
        auto tv73 = expand(broadcast(tv68, {false, false}), {dim0_val, dim1_val});
        auto s74 = IrBuilder::create<Val>(2048.0, DataType::Double);
        auto s75 = reciprocal(s74);
        auto tv76 = mul(tv69, s75);
        auto s77 = IrBuilder::create<Val>(0.000488281, DataType::Double);
        auto tv78 = mul(s77, tv73);
        auto tv79 = mul(tv21, tv31);
        auto tv80 = add(tv78, tv76);
        auto tv81 = mul(tv79, tv13);
        auto tv82 = add(tv36, tv80);
        auto tv83 = sum(tv81, {0}, false);
        auto tv84 = sum(tv13, {0}, false);
        auto tv85 = castOp(DataType::BFloat16, tv82);
        auto tv86 = castOp(DataType::BFloat16, tv83);
        auto tv87 = castOp(DataType::BFloat16, tv84);
        fusion->addOutput(tv87);
        fusion->addOutput(tv86);
        fusion->addOutput(tv85);
      }
    
      auto fusion_copy = *fusion_ptr;
      auto options_fp32 =
          at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
      auto options_fp16 =
          at::TensorOptions().dtype(at::kBFloat16).device(at::kCUDA, 0);
      auto t0 = at::randn({dim1}, options_fp16);
      auto t1 = at::randn({dim0}, options_fp32);
      auto t2 = at::randn({dim0, dim1}, options_fp16);
      auto t3 = at::randn({dim0, dim1}, options_fp16);
      auto t4 = at::randn({dim0, 1}, options_fp32);
      KernelArgumentHolder args = {t0, t1, t2, t3, t4};
      FusionExecutorCache executor_cache(std::move(fusion_ptr));
      auto cg_outputs = executor_cache.runFusionWithInputs(args);
    
      // Generate expected outputs using ATen computations
      auto t0_fp32 = t0.to(at::kFloat);
      auto t2_fp32 = t2.to(at::kFloat);
      auto t3_fp32 = t3.to(at::kFloat);
      auto t4_fp32 = t4.to(at::kFloat);
    
      // Step-by-step computation matching the fusion
      auto tv8 = t0_fp32.unsqueeze(0).expand({dim0, dim1}); // broadcast t0
      auto tv12 = t1.unsqueeze(1).expand({dim0, 1}); // broadcast t1
      auto tv13 = t2_fp32; // cast t2 to float
      auto tv14 = tv8; // cast tv8 to float
      auto tv18 = tv12.expand({dim0, dim1}); // expand tv12
      auto tv19 = t3_fp32; // cast t3 to float
      auto tv20 = tv14 * tv13; // mul(tv14, tv13)
      auto tv21 = tv19 - tv18; // sub(tv19, tv18)
      auto tv22 = tv21 * tv20; // mul(tv21, tv20)
      auto tv23 = tv22.sum(1, false); // sum(tv22, {1})
      auto tv27 = tv23.unsqueeze(1).expand({dim0, 1}); // broadcast tv23
      auto tv31 = t4_fp32.expand({dim0, dim1}); // expand t4
      auto tv33 = t4_fp32.pow(3.0); // pow(t4, 3.0)
      auto tv35 = -0.5 * tv27; // mul(-0.5, tv27)
      auto tv36 = tv31 * tv20; // mul(tv31, tv20)
      auto tv37 = tv35 * tv33; // mul(tv35, tv33)
      auto tv38 = -tv36; // neg(tv36)
      auto tv39 = tv37.sum(1, false); // sum(tv37, {1})
      auto tv40 = tv38.sum(1, false); // sum(tv38, {1})
      auto tv44 = t1.unsqueeze(1).expand({dim0, 1}); // broadcast t1
      auto tv48 = tv39.unsqueeze(1).expand({dim0, 1}); // broadcast tv39
      auto tv52 = tv40.unsqueeze(1).expand({dim0, 1}); // broadcast tv40
      auto tv56 = tv44.expand({dim0, dim1}); // expand tv44
      auto tv60 = tv48.expand({dim0, dim1}); // expand tv48
      auto tv61 = tv52.sum(1, false); // sum(tv52, {1})
      auto tv62 = tv19 - tv56; // sub(tv19, tv56)
      auto tv64 = 2.0 * tv60; // mul(2.0, tv60)
      auto tv68 = tv61.unsqueeze(1).expand({dim0, 1}); // broadcast tv61
      auto tv69 = tv64 * tv62; // mul(tv64, tv62)
      auto tv73 = tv68.expand({dim0, dim1}); // expand tv68
      auto tv75 = 1.0 / 2048.0; // reciprocal(2048.0)
      auto tv76 = tv69 * tv75; // mul(tv69, tv75)
      auto tv77 = 0.000488281; // constant
      auto tv78 = tv77 * tv73; // mul(tv77, tv73)
      auto tv79 = tv21 * tv31; // mul(tv21, tv31)
      auto tv80 = tv78 + tv76; // add(tv78, tv76)
      auto tv81 = tv79 * tv13; // mul(tv79, tv13)
      auto tv82 = tv36 + tv80; // add(tv36, tv80)
      auto tv83 = tv81.sum(0, false); // sum(tv81, {0})
      auto tv84 = tv13.sum(0, false); // sum(tv13, {0})
    
      // Expected outputs (cast to BFloat16)
      auto expected_output0 = tv84.to(at::kBFloat16); // tv87
      auto expected_output1 = tv83.to(at::kBFloat16); // tv86
      auto expected_output2 = tv82.to(at::kBFloat16); // tv85
    
      std::vector<at::Tensor> expected_outputs = {
          expected_output0, expected_output1, expected_output2};
    
      testValidate(
          &fusion_copy, cg_outputs, args, expected_outputs, __LINE__, __FILE__);
    
      // TODO: Fix auto validation - the fusion runs but testValidate has type
      // conversion issues

    @liqiangxl marked this pull request as ready for review June 24, 2025 15:01
    @liqiangxl requested a review from jjsjann123 June 24, 2025 15:01
    @liqiangxl (Collaborator, Author)

    !test

    @jjsjann123 (Collaborator) left a comment


    What's with the perf regression with hidden size around 25k?! That looked pretty bad.

    } else {
    cached_tv->axis(last_iter_dim)->parallelize(ParallelType::Unroll);
    // skip tvs that are already vectorized
    if (cached_tv->axis(-1)->getParallelType() == ParallelType::Vectorize) {
    Collaborator

    QQ: is it intentional to check axis(-1)? Since our schedule below is doing cached_tv->axis(last_iter_dim)->parallelize(ParallelType::Vectorize);

    @liqiangxl (Collaborator, Author)

    yes, extended comments to:

            // skip tvs that are already vectorized in general vectorization
            // analysis and propagation.
    

    @@ -325,11 +325,10 @@ std::vector<TensorView*> getGroupedReductionPersistentTvs(
       }
     }
     for (auto tv : p_of_reductions) {
    -  // must exists in both set, not same as inner_bcast_tv, and
    +  // must exists in both set, and
    Collaborator

    nit: format.

    Comment on lines +1540 to +1543
    // TODO: Fix auto validation - the fusion runs but testValidate has type
    // conversion issues testValidate(&fusion_copy, cg_outputs, args,
    // expected_outputs,
    // __LINE__, __FILE__);
    Collaborator

    errr, the reference implementation looks painful.

    What's the issue with the validation? is it type inference that went wrong?

    @liqiangxl (Collaborator, Author)

    The error message is "Result is dynamic but not convertible to result type". I created an issue at #4679.

    @liqiangxl (Collaborator, Author)

    What's with the perf regression with hidden size around 25k?! That looked pretty bad.

    No need to worry about that. The performance was measured by enforcing the use of the warp-specialized approach, which is not yet enabled by default. An upper-level heuristic (see #4603) chooses between the warp-specialized and multi-wave approaches; it selects the multi-wave approach when the hidden size is larger than 24K, which avoids this regression.

    @liqiangxl (Collaborator, Author)

    !test

    @jjsjann123 (Collaborator) left a comment


    LGTM

    @liqiangxl (Collaborator, Author)

    !test --dev

    @liqiangxl (Collaborator, Author)

    !test --dev

    @liqiangxl (Collaborator, Author)

    !test

    @liqiangxl merged commit b54e1fa into main Jun 27, 2025
    50 of 52 checks passed
    @liqiangxl deleted the llu/ws_tma_lnopt branch June 27, 2025 11:37