auto select between warp specialized and multi-wave approaches #4603
base: llu/ws_tma_lnopt
Conversation
### Scheduler changes

(1) `TIDy` is used to parallelize independent compute warp groups.
(2) Revise codegen to ensure different warp groups use different reduction/broadcast workspaces and sync barriers.
(3) Other minor changes, e.g. avoid unrolling the output tensor to save registers; the smem buffer size should account for the iter-group count; unroll "prefetch" is disabled for the non-matmul computation branch to avoid instruction cache misses, similar to #3818.

### Heuristic changes

**General idea:** Optimize register and shared memory usage to achieve multiple independent compute warp groups, unrolled iteration domains, and deep circular buffering.

**This is still a rough version that only ensures correctness. It will be fine-tuned for other fusions, along with auto-selecting between the warp specialized and multi-wave approaches, e.g. #4603.**

**Key parameters:**

Four ints:
- `bdimx`: parallelizes the inner dim, e.g. 128, 256; influences register usage
- `bdimy`: used for warp specialization and independent compute warp groups, e.g. 1, 2, 3
- `iter_unroll`: unroll factor of the iteration dim, e.g. 1, 2, 4
- `n_stages`: circular buffer stages, e.g. 2, 4, 8

Two bools:
- `is_circular_buffer_regs_cached`: cache the TMA-loaded buffer in registers
- `is_non_circular_buffer_gmem_to_regs`: load non-circular-buffered tvs directly from gmem to registers

**Logic to update key parameters in `update_heuristics()`:**

Start with `bdimx = 128, bdimy = 1, iter_unroll = 1, n_stages = 1` and loop until nothing is updated:
(1) Try to increase `n_stages` to the target; check shared memory usage (`n_stages` does not influence register usage).
(2) Try to increase `bdimy` to the target; check shared memory and register usage.
(3) Try to increase `iter_unroll` to the target; check shared memory and register usage.
(4) If `bdimy == 1`, increase `bdimx`; check shared memory usage.

**Workflow of the heuristics:**

Call `update_heuristics()` with varied configurations:
(1) Initial attempt: `is_circular_buffer_regs_cached = true`, `is_non_circular_buffer_gmem_to_regs = true`, `target_stages = 2, target_bdimy = 2, target_iter_unroll = 2`.
(2) First fallback, when `bdimy == 1`: reduce register usage by setting `is_circular_buffer_regs_cached = false` and `target_iter_unroll = 1`.
(3) Second fallback, when `bdimy == 1`: further reduce register usage by setting `is_non_circular_buffer_gmem_to_regs = false`.
(4) Last fallback, when `n_stages == 1`: reduce shared memory to achieve circular buffering by setting `is_non_circular_buffer_gmem_to_regs = true`.
At last, further increase `target_stages` if there is unused shared memory.

### Performance & influence of different parameters

**After this PR** (benchmark figure). **After #4599** (benchmark figure).

**Parameter analysis:** see pages 1 to 7 in these [slides](https://docs.google.com/presentation/d/1Z_4c8dhzy_4Px5WfQ1-zP_uq8PohbYVQojjX2SmSkhY/edit?usp=sharing).

---------

Co-authored-by: jjsjann123 <[email protected]>
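The greedy parameter-update loop described above can be sketched as follows. This is a self-contained toy, not nvFuser code: the smem/register cost formulas, the budgets, and the 256 cap on `bdimx` are illustrative assumptions standing in for the scheduler's real resource analysis.

```cpp
#include <cassert>

// Toy stand-in for the scheduler's heuristic parameters.
struct Heuristic {
  int bdimx = 128;
  int bdimy = 1;
  int iter_unroll = 1;
  int n_stages = 1;
  bool is_circular_buffer_regs_cached = true;
  bool is_non_circular_buffer_gmem_to_regs = true;
};

// Illustrative budgets and cost formulas; real values depend on the fusion
// and the target GPU, and come from the scheduler's own analysis.
constexpr int kSmemBudgetBytes = 228 * 1024;
constexpr int kRegBudgetPerCta = 64 * 1024;

int smemUse(const Heuristic& h) {
  // Circular-buffered TMA tiles scale with stage count and iter unrolling.
  return h.n_stages * h.iter_unroll * 16 * 1024;
}

int regUse(const Heuristic& h) {
  int regs = h.bdimx * h.bdimy * h.iter_unroll / 4;
  if (h.is_circular_buffer_regs_cached) regs += 8 * 1024;
  if (h.is_non_circular_buffer_gmem_to_regs) regs += 8 * 1024;
  return regs;
}

bool fitsSmem(const Heuristic& h) { return smemUse(h) <= kSmemBudgetBytes; }
bool fitsBoth(const Heuristic& h) {
  return fitsSmem(h) && regUse(h) <= kRegBudgetPerCta;
}

// Greedy loop: raise one parameter at a time toward its target, accepting a
// step only if the configuration still fits the budgets.
void update_heuristics(Heuristic& h, int target_stages, int target_bdimy,
                       int target_iter_unroll) {
  bool changed = true;
  while (changed) {
    changed = false;
    // (1) n_stages only consumes shared memory, never registers.
    if (h.n_stages < target_stages) {
      Heuristic t = h;
      t.n_stages++;
      if (fitsSmem(t)) { h = t; changed = true; }
    }
    // (2) More compute warp groups cost both smem and registers.
    if (h.bdimy < target_bdimy) {
      Heuristic t = h;
      t.bdimy++;
      if (fitsBoth(t)) { h = t; changed = true; }
    }
    // (3) Iteration-dim unrolling also costs both.
    if (h.iter_unroll < target_iter_unroll) {
      Heuristic t = h;
      t.iter_unroll *= 2;
      if (fitsBoth(t)) { h = t; changed = true; }
    }
    // (4) No ping-pong warp groups: widen the block instead.
    //     The 256 cap is an assumption to keep this sketch terminating.
    if (h.bdimy == 1 && h.bdimx < 256) {
      Heuristic t = h;
      t.bdimx *= 2;
      if (fitsSmem(t)) { h = t; changed = true; }
    }
  }
}
```

The workflow would then call `update_heuristics()` repeatedly with the fallback configurations listed above, relaxing the bool flags and targets between attempts.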
!test

!test
perf looks amazing. 👏
I'm wondering if I mis-read the heuristic logic, or if the benchmark is measured using the right scheduling scheme.
csrc/scheduler/utils.h
Outdated
@@ -848,6 +848,8 @@ TensorView* getUpCastInputOf(const TensorView* buffer_tv);
//! See device_lower/analysis/tensor_producer_aliases.h
TensorView* scheduleInputToSkipIntermediates(TensorView* tv);

// Returns true if any of the domains of the tensor is symbolic
bool isConcreteTensor(const TensorView* tv);
nitpick, maybe we can just call this SymbolicTensor? since the code comment is saying that already.
I was initially using isSymbolicTensor(), but then noticed that "symbolic" has a different meaning in IterType::Symbolic. So I switched to isConcreteTensor. After your suggestion, it seems reasonable to go back to using isSymbolicTensor(), since we're checking a tensor, not an iter domain.
We also use the term "symbolic" in IterType::Symbolic, which refers to a temporary state during fusion definition and compilation. This state is later resolved to either IterType::Iteration or IterType::Broadcast during concretization. In this case, we're explicitly referring to SymbolicTensor, so it should be fine.
I see what you mean. yeah, concretization took out so many good names 😆
Thanks for bearing with my nitpicking.
// static CTA size
auto inp_tvs = ir_utils::filterByType<TensorView>(fusion->inputs());
if (std::any_of(inp_tvs.begin(), inp_tvs.end(), [](TensorView* tv) {
      return scheduler_utils::isConcreteTensor(tv);
The logic looks wrong to me.
- return scheduler_utils::isConcreteTensor(tv);
+ return !scheduler_utils::isConcreteTensor(tv);
Changed to isSymbolicTensor, so no need to change the logic here.
The logic was wrong because I changed the function name from isSymbolicTensor to isConcreteTensor without actually changing the logic. Now we are changing back to isSymbolicTensor.
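For reference, the distinction the thread settles on can be sketched with a minimal stand-in. The structs below are illustrative, not nvFuser's real IterDomain/TensorView: a domain is "symbolic" here simply if its extent is not a compile-time constant, which is unrelated to IterType::Symbolic.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Simplified stand-ins for nvFuser's IterDomain/TensorView (illustrative).
struct IterDomain {
  std::optional<int64_t> extent;  // nullopt => extent not known at compile time
};
struct TensorView {
  std::vector<IterDomain> domains;
};

// A tensor is "symbolic" if any of its domains has a non-constant extent.
bool isSymbolicTensor(const TensorView* tv) {
  return std::any_of(
      tv->domains.begin(), tv->domains.end(),
      [](const IterDomain& id) { return !id.extent.has_value(); });
}
```

With this spelling, the quoted call site can use the predicate directly inside `std::any_of` without the negation suggested in the review.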
runtime_info.getIndexType());

// If warp specialized is enabled, or the heuristic is successful, return
if (hp.is_warp_specialized || rparams->is_good_ws_heuristic) {
out of curiosity, what happens when we have hp.is_warp_specialized set as true, but we failed to get a good heuristic?
It leads to poor performance but still gives correct results. Useful for comparing the performance of different approaches.
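For context, here is a minimal sketch of the "good heuristic" check being discussed, following the rejection rules listed in the PR summary. The struct, field names, and the way the 64-register threshold is threaded through are assumptions for illustration, not the actual nvFuser code.

```cpp
#include <cassert>

// Illustrative stand-in for the warp specialized heuristic parameters.
struct WsHeuristic {
  int n_stages = 1;
  int bdimy = 1;
  bool is_non_circular_buffer_gmem_to_regs = false;
  int available_non_buffer_regs = 0;
};

bool is_good_ws_heuristic(const WsHeuristic& h) {
  // Rule 1: a single stage means circular buffering was not achieved; reject.
  if (h.n_stages == 1) {
    return false;
  }
  // Rule 2: no ping-pong (bdimy == 1) combined with gmem->regs loading raises
  // register pressure; only acceptable with >= 64 spare non-buffer registers
  // (the empirical threshold quoted in the summary).
  if (h.bdimy == 1 && h.is_non_circular_buffer_gmem_to_regs &&
      h.available_non_buffer_regs < 64) {
    return false;
  }
  return true;
}
```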
LGTM
Summary

This PR introduces automatic selection between warp specialized and multi-wave approaches for normalization kernels in the `InnerOuterPersistentKernelScheduler`. The scheduler now chooses the optimal approach based on hardware capabilities, input characteristics, and workload properties.

Key Changes

1. Enhanced Heuristic Selection Logic

The `getInnerOuterPersistentHeuristics()` function now implements a two-tier selection strategy:
- User-requested selection: when `EnableOption::WarpSpecializedNormalization` is explicitly enabled, the scheduler unconditionally uses the warp specialized approach.
- Automatic heuristic-based selection: the scheduler automatically generates warp specialized heuristics when all the following conditions are met:

2. Fallback Mechanism

Unless `EnableOption::WarpSpecializedNormalization` is explicitly enabled, the scheduler falls back to the multi-wave approach when `is_good_ws_heuristic()` returns `false`. The following heuristics are considered 'bad':
- `n_stages == 1`: the heuristic cannot achieve circular buffering and is rejected.
- `bdimy == 1 && is_non_circular_buffer_gmem_to_regs`: ping-pong is not used, and the heuristic tries to reduce shared memory usage by loading data directly from global memory to registers. This increases register pressure and may lead to register spills. The method is considered beneficial only when at least 64 non-buffer registers are available, to avoid excessive spilling. This threshold is based on empirical results from the RMSNorm backward pass in FP16 on B200, with a practical cut-off around a hidden size of 24K.

3. Implementation Details

- `preferWarpSpecialized()` implements the heuristic logic for automatic selection.
- `SchedulerHyperParameters` now includes an `is_warp_specialized` flag.

Performance Impact
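The two-tier selection described under Key Changes can be sketched as follows. The function name and signature are assumptions for illustration; the three bools stand in for the real option query, `preferWarpSpecialized()`, and `is_good_ws_heuristic`.

```cpp
#include <cassert>

enum class Approach { WarpSpecialized, MultiWave };

// Illustrative sketch of the two-tier selection in
// getInnerOuterPersistentHeuristics() (assumed names).
Approach selectApproach(bool user_enabled_ws, bool prefer_ws,
                        bool good_ws_heuristic) {
  // Tier 1: an explicit EnableOption::WarpSpecializedNormalization request
  // wins unconditionally, even if the generated heuristic is poor (correct
  // but slow, which is useful for comparing approaches).
  if (user_enabled_ws) {
    return Approach::WarpSpecialized;
  }
  // Tier 2: automatic selection, falling back to multi-wave when the warp
  // specialized heuristic is rejected.
  if (prefer_ws && good_ws_heuristic) {
    return Approach::WarpSpecialized;
  }
  return Approach::MultiWave;
}
```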