
auto select between warp specialized and multi-wave approaches #4603


Open
wants to merge 80 commits into base: llu/ws_tma_lnopt

Conversation

liqiangxl
Collaborator

@liqiangxl liqiangxl commented Jun 9, 2025

Summary

This PR introduces automatic selection between the warp specialized and multi-wave approaches for normalization kernels in the InnerOuterPersistentKernelScheduler. The scheduler now chooses between the two based on hardware capability, input shape concreteness, and workload size.

Key Changes

1. Enhanced Heuristic Selection Logic

The getInnerOuterPersistentHeuristics() function now implements a two-tier selection strategy:

User-Requested Selection

  • When EnableOption::WarpSpecializedNormalization is explicitly enabled, the scheduler unconditionally uses the warp specialized approach
  • This provides users with direct control over the scheduling strategy

Automatic Heuristic-Based Selection

The scheduler automatically generates warp specialized heuristics when all of the following conditions are met (a sketch of this predicate follows the list):

  1. GPU Architecture ≥ 10: The multi-wave approach already works well on Hopper GPUs, so warp specialization is only auto-enabled for Blackwell and later architectures
  2. Concretized Input Tensors: The current implementation requires static shapes for warp specialization (dynamic inputs are not yet supported)
  3. Sufficient Iteration Domain Size:
    • RMS Norm Backward: Requires more than 4 × SM_count total rows (i.e., > 4 rows per SM)
    • Layer Norm Backward: Requires more than 16 × SM_count total rows (i.e., > 16 rows per SM)
    • This ensures deep circular buffering and amortizes the cost of loading the weight tensors (which are shared across batches)
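
A minimal sketch of this predicate, assuming illustrative parameter names (the real `preferWarpSpecialized()` introduced by this PR operates on the fusion and runtime info; see Implementation Details below):

```cpp
#include <cstdint>

// Sketch only: mirrors the three auto-selection conditions above.
// All parameter names are illustrative, not the actual signature.
bool preferWarpSpecialized(
    int sm_major,             // compute capability major version
    bool has_symbolic_input,  // any input tensor with a non-static shape?
    int64_t n_rows,           // size of the iteration (outer) domain
    int64_t sm_count,         // number of SMs on the device
    bool is_rms_norm_bwd) {   // RMS Norm bwd vs. Layer Norm bwd
  // (1) Multi-wave is already good on Hopper; auto-enable only on
  //     Blackwell (sm_10x) and later.
  if (sm_major < 10) {
    return false;
  }
  // (2) Warp specialization currently requires static input shapes.
  if (has_symbolic_input) {
    return false;
  }
  // (3) Require enough rows for deep circular buffering so that the
  //     weight-tensor loads (shared across batches) amortize.
  const int64_t multiple = is_rms_norm_bwd ? 4 : 16;
  return n_rows > multiple * sm_count;
}
```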

2. Fallback Mechanism

  • No fallback if EnableOption::WarpSpecializedNormalization is explicitly enabled.
  • For auto-generated heuristics, the scheduler falls back to the multi-wave approach if is_good_ws_heuristic() returns false. A heuristic is considered 'bad' in the following cases (see the sketch after this list):
    • Single Stage Detection: If n_stages == 1, the heuristic cannot achieve circular buffering and is rejected
    • Register Spill Prevention: If bdimy == 1 && is_non_circular_buffer_gmem_to_regs, ping-pong is not used and the heuristic reduces shared memory usage by loading data directly from global memory to registers. This increases register pressure and may lead to register spills, so the method is considered beneficial only when at least 64 non-buffer registers are available. The threshold is based on empirical results from the RMSNorm backward pass in FP16 on B200, with a practical cut-off around a hidden size of 24K.
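
A minimal sketch of the rejection rules, again with assumed names (the real check is exposed via `rparams->is_good_ws_heuristic`):

```cpp
#include <cstdint>

// Sketch only: encodes the two 'bad heuristic' rules above.
bool isGoodWsHeuristic(
    int64_t n_stages,
    int64_t bdimy,
    bool is_non_circular_buffer_gmem_to_regs,
    int64_t non_buffer_registers) {
  // Rule 1: a single stage cannot achieve circular buffering.
  if (n_stages == 1) {
    return false;
  }
  // Rule 2: without ping-pong (bdimy == 1), loading gmem -> regs needs
  // register headroom or it spills; 64 is the empirical threshold from
  // FP16 RMSNorm bwd on B200 (cut-off near hidden size 24K).
  if (bdimy == 1 && is_non_circular_buffer_gmem_to_regs &&
      non_buffer_registers < 64) {
    return false;
  }
  return true;
}
```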

3. Implementation Details

  • New Function: preferWarpSpecialized() implements the heuristic logic for automatic selection
  • Enhanced Parameters: SchedulerHyperParameters now includes is_warp_specialized flag
  • Graceful Degradation: Failed warp specialized attempts seamlessly transition to multi-wave scheduling (the overall flow is sketched below)
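
A hedged sketch of how these pieces could fit together; the inner return condition mirrors a diff excerpt quoted later in this thread, while `warpSpecializedHeuristic()` and `multiWaveHeuristic()` are placeholder names, not the actual functions:

```cpp
// Sketch only: inferred control flow of getInnerOuterPersistentHeuristics().
std::unique_ptr<ReductionParams> getInnerOuterPersistentHeuristics(
    Fusion* fusion, SchedulerRuntimeInfo& runtime_info) {
  SchedulerHyperParameters hp;
  // Tier 1: the user explicitly requested warp specialization.
  hp.is_warp_specialized =
      isOptionEnabled(EnableOption::WarpSpecializedNormalization);
  // Tier 2: automatic heuristic-based selection.
  if (hp.is_warp_specialized || preferWarpSpecialized(fusion, runtime_info)) {
    auto rparams = warpSpecializedHeuristic(fusion, runtime_info);
    // If warp specialized is enabled, or the heuristic is good, keep it.
    if (hp.is_warp_specialized || rparams->is_good_ws_heuristic) {
      return rparams;
    }
  }
  // Graceful degradation: fall back to multi-wave scheduling.
  return multiWaveHeuristic(fusion, runtime_info);
}
```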

Performance Impact

image

liqiangxl added a commit that referenced this pull request Jun 18, 2025
### Scheduler changes
(1) `TIDy` is used to parallelize independent computation warp groups.
(2) Revise codegen to ensure different warp groups use different reduction/broadcast workspaces and sync barriers.
(3) Other minor changes, e.g. avoid unrolling the output tensor to save registers, make the smem buffer size account for the iter-grouped number, and disable unroll "prefetch" for the non-matmul computation branch to avoid instruction cache misses, similar to
#3818

### Heuristic changes
**General idea:**
Optimize register and shared memory usage to achieve multiple
independent compute warp groups, unrolled iteration domains, and deep
circular buffering.

**This is still a rough version that just ensures correctness. It will be fine-tuned for other fusions and for automatic selection between the warp specialized and multi-wave approaches, e.g. #4603.**

**Key parameters:**
Four ints:
`bdimx`: parallelizes the inner dim, e.g. 128, 256; influences register usage
`bdimy`: used for warp specialization and independent compute warp groups, e.g. 1, 2, 3
`iter_unroll`: unroll factor of the iteration dim, e.g. 1, 2, 4
`n_stages`: circular buffer stages, e.g. 2, 4, 8
Two bools:
`bool is_circular_buffer_regs_cached`: cache the TMA-loaded buffer in registers
`bool is_non_circular_buffer_gmem_to_regs`: load non-circular-buffered tvs directly from gmem to registers

**Logic to update the key parameters in `update_heuristics()`** (sketched after the list)
Start with `bdimx = 128, bdimy = 1, iter_unroll = 1, n_stages = 1`, and loop until nothing changes:
(1) Try to increase `n_stages` to the target; check shared memory usage (`n_stages` does not influence register usage).
(2) Try to increase `bdimy` to the target; check shared memory and register usage.
(3) Try to increase `iter_unroll` to the target; check shared memory and register usage.
(4) If `bdimy == 1`, increase `bdimx`; check shared memory usage.
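
A hedged sketch of this greedy update loop. The `WsParams` struct and the resource-check stubs are illustrative assumptions; the real code evaluates actual shared memory and register budgets for the fusion:

```cpp
#include <cstdint>

// Sketch only: parameter bundle and update loop described above.
struct WsParams {
  int64_t bdimx = 128;
  int64_t bdimy = 1;
  int64_t iter_unroll = 1;
  int64_t n_stages = 1;
};

// Stubbed resource checks (assumptions, not the real cost model).
bool fitsSharedMemory(const WsParams& p) { return true; }
bool fitsRegisters(const WsParams& p) { return true; }

void updateHeuristics(
    WsParams& p,
    int64_t target_stages,
    int64_t target_bdimy,
    int64_t target_iter_unroll,
    int64_t max_bdimx) {
  bool changed = true;
  while (changed) {
    changed = false;
    // (1) Deepen circular buffering: only shared memory is affected.
    if (p.n_stages < target_stages) {
      WsParams next = p;
      next.n_stages *= 2;
      if (fitsSharedMemory(next)) { p = next; changed = true; }
    }
    // (2) More independent compute warp groups: smem and registers.
    if (p.bdimy < target_bdimy) {
      WsParams next = p;
      next.bdimy++;
      if (fitsSharedMemory(next) && fitsRegisters(next)) { p = next; changed = true; }
    }
    // (3) Unroll the iteration dim: smem and registers.
    if (p.iter_unroll < target_iter_unroll) {
      WsParams next = p;
      next.iter_unroll *= 2;
      if (fitsSharedMemory(next) && fitsRegisters(next)) { p = next; changed = true; }
    }
    // (4) No ping-pong (bdimy == 1): widen the block instead.
    if (p.bdimy == 1 && p.bdimx < max_bdimx) {
      WsParams next = p;
      next.bdimx *= 2;
      if (fitsSharedMemory(next)) { p = next; changed = true; }
    }
  }
}
```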

**Workflow of the heuristics:** (the cascade is sketched after the list)
Call `update_heuristics()` with varied configurations.
(1) Initial attempt:
`is_circular_buffer_regs_cached = true`
`is_non_circular_buffer_gmem_to_regs = true`
`target_stages = 2, target_bdimy = 2, target_iter_unroll = 2`

(2) First fallback when `bdimy == 1`: reduce register usage by setting `is_circular_buffer_regs_cached = false`, `target_iter_unroll = 1`.

(3) Second fallback when `bdimy == 1`: further reduce register usage by setting `is_non_circular_buffer_gmem_to_regs = false`.

(4) Last fallback when `n_stages == 1`: reduce shared memory to achieve circular buffering by setting `is_non_circular_buffer_gmem_to_regs = true`.

Finally, further increase `target_stages` if there is unused shared memory.
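
A hedged sketch of this cascade, reusing the `WsParams`/`updateHeuristics()` sketch above; the per-fallback targets and the `max_bdimx` bound are assumptions paraphrased from this description:

```cpp
// Sketch only: the fallback cascade around updateHeuristics().
struct WsConfig {
  WsParams params;
  bool is_circular_buffer_regs_cached = true;
  bool is_non_circular_buffer_gmem_to_regs = true;
};

WsConfig selectWsConfig() {
  WsConfig c;
  auto run = [&](int64_t target_iter_unroll) {
    c.params = WsParams{};  // restart from the defaults
    updateHeuristics(c.params, /*target_stages=*/2, /*target_bdimy=*/2,
                     target_iter_unroll, /*max_bdimx=*/512);
  };
  // (1) Initial attempt.
  run(/*target_iter_unroll=*/2);
  // (2) First fallback: no ping-pong achieved, cut register pressure.
  if (c.params.bdimy == 1) {
    c.is_circular_buffer_regs_cached = false;
    run(1);
  }
  // (3) Second fallback: keep non-circular buffers in shared memory.
  if (c.params.bdimy == 1) {
    c.is_non_circular_buffer_gmem_to_regs = false;
    run(1);
  }
  // (4) Last fallback: free shared memory to enable circular buffering.
  if (c.params.n_stages == 1) {
    c.is_non_circular_buffer_gmem_to_regs = true;
    run(1);
  }
  // Spend any leftover shared memory on deeper circular buffering.
  WsParams deeper = c.params;
  deeper.n_stages *= 2;
  if (fitsSharedMemory(deeper)) {
    c.params = deeper;
  }
  return c;
}
```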

### Performance & Influence of different paras:

**After this PR**

![image](https://github.com/user-attachments/assets/cb07cf2c-f8b0-4d9d-9e08-e676049e463c)

After #4599 

![image](https://github.com/user-attachments/assets/4d5807bd-367a-4e2e-b397-f3afc129009a)

**Parameter analysis:**
See pages 1 to 7 in this
[slide](https://docs.google.com/presentation/d/1Z_4c8dhzy_4Px5WfQ1-zP_uq8PohbYVQojjX2SmSkhY/edit?usp=sharing).

---------

Co-authored-by: jjsjann123 <[email protected]>
Base automatically changed from llu/ws_tma_pingpong_static_warp to main June 18, 2025 18:16
@liqiangxl
Collaborator Author

!test

@liqiangxl liqiangxl changed the base branch from main to llu/ws_tma_lnopt June 24, 2025 15:05
@liqiangxl
Collaborator Author

!test

@liqiangxl liqiangxl marked this pull request as ready for review June 24, 2025 18:16
@liqiangxl liqiangxl requested a review from jjsjann123 June 24, 2025 18:16
Collaborator

@jjsjann123 jjsjann123 left a comment


perf looks amazing. 👏

I'm wondering if I mis-read the heuristic logic, or if the benchmark is measured using the right scheduling scheme.

@@ -848,6 +848,8 @@ TensorView* getUpCastInputOf(const TensorView* buffer_tv);
//! See device_lower/analysis/tensor_producer_aliases.h
TensorView* scheduleInputToSkipIntermediates(TensorView* tv);

// Returns true if any of the domains of the tensor is symbolic
bool isConcreteTensor(const TensorView* tv);
Collaborator


nitpick, maybe we can just call this SymbolicTensor? since the code comment is saying that already.

Collaborator Author


I was initially using isSymbolicTensor(), but then noticed that "symbolic" has a different meaning in IterType::Symbolic, so I switched to isConcreteTensor. After your suggestion, it seems reasonable to go back to isSymbolicTensor(), since we're checking a tensor, not an iter domain.

We also use the term "symbolic" in IterType::Symbolic, which refers to a temporary state during fusion definition and compilation. This state is later resolved to either IterType::Iteration or IterType::Broadcast during concretization.
In this case, we're explicitly referring to SymbolicTensor, so it should be fine.

Collaborator


I see what you mean. yeah, concretization took out so many good names 😆

Thanks for bearing with my nitpicking.

// static CTA size
auto inp_tvs = ir_utils::filterByType<TensorView>(fusion->inputs());
if (std::any_of(inp_tvs.begin(), inp_tvs.end(), [](TensorView* tv) {
return scheduler_utils::isConcreteTensor(tv);
Collaborator


The logic looks wrong to me.

Suggested change
return scheduler_utils::isConcreteTensor(tv);
return !scheduler_utils::isConcreteTensor(tv);

Collaborator Author


Changed to isSymbolicTensor, so no need to change the logic here.

Collaborator Author


The logic was wrong because I changed the function name from isSymbolicTensor to isConcreteTensor without actually changing the logic. Now we are changing back to isSymbolicTensor.
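
Putting the resolution together, a sketch of the renamed helper and its call site, recombined from the diff excerpts in this thread (the actual code may differ):

```cpp
// Returns true if any of the domains of the tensor is symbolic.
bool isSymbolicTensor(const TensorView* tv);

// Call site sketch: skip the static-CTA warp specialized path when any
// input still has a symbolic (non-concrete) shape.
auto inp_tvs = ir_utils::filterByType<TensorView>(fusion->inputs());
if (std::any_of(inp_tvs.begin(), inp_tvs.end(), [](TensorView* tv) {
      return scheduler_utils::isSymbolicTensor(tv);
    })) {
  // ... fall back to the multi-wave approach ...
}
```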

runtime_info.getIndexType());

// If warp specialized is enabled, or the heuristic is successful, return
if (hp.is_warp_specialized || rparams->is_good_ws_heuristic) {
Collaborator


Out of curiosity, what happens when we have hp.is_warp_specialized set to true, but we failed to get a good heuristic?

Collaborator Author


It leads to poor performance but still gives correct results. Useful for comparing the performance of different approaches.

Collaborator

@jjsjann123 jjsjann123 left a comment


LGTM
