fix warp specialized tma for ln bwd #4663
Conversation
!test
Review updated until commit d33dc5c
!test
What's with the perf regression with hidden size around 25k?! That looked pretty bad.
} else {
cached_tv->axis(last_iter_dim)->parallelize(ParallelType::Unroll);
// skip tvs that are already vectorized
if (cached_tv->axis(-1)->getParallelType() == ParallelType::Vectorize) {
QQ: is it intentional to check `axis(-1)`? Since our schedule below is doing `cached_tv->axis(last_iter_dim)->parallelize(ParallelType::Vectorize);`
yes, extended comments to:
// skip tvs that are already vectorized in general vectorization
// analysis and propagation.
@@ -325,11 +325,10 @@ std::vector<TensorView*> getGroupedReductionPersistentTvs(
}
}
for (auto tv : p_of_reductions) {
// must exists in both set, not same as inner_bcast_tv, and
// must exists in both set, and
nit: format.
// TODO: Fix auto validation - the fusion runs but testValidate has type
// conversion issues testValidate(&fusion_copy, cg_outputs, args,
// expected_outputs,
// __LINE__, __FILE__);
errr, the reference implementation looks painful.
What's the issue with the validation? is it type inference that went wrong?
The error message is `Result is dynamic but not convertible to result type`; I created an issue at #4679.
No need to worry about that. The performance was measured by enforcing the use of the warp-specialized approach, which is not yet enabled by default. An upper-level heuristic (see #4603) chooses between the warp-specialized and multi-wave approaches; it selects the multi-wave approach when the hidden size is larger than 24K, which avoids these regressions.
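A minimal sketch of the selection described above, assuming a simple hidden-size cutoff (the function and constant names are hypothetical; the actual heuristic lives in #4603):

```cpp
// Hypothetical illustration only: prefer the warp-specialized approach for
// smaller hidden sizes and fall back to multi-wave above ~24K, the cutoff
// mentioned in the comment above.
bool preferWarpSpecialized(int64_t hidden_size) {
  constexpr int64_t kMaxWarpSpecializedHiddenSize = 24 * 1024;
  return hidden_size <= kMaxWarpSpecializedHiddenSize;
}
```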
!test
LGTM
!test --dev
!test --dev
!test
Summary
This PR addresses a remaining inline issue for layer norm backward operations generated by Thunder, building upon the previous fix in #4561.
Issue
The `getGroupedReductionPersistentTvs()` function was designed to identify persistent tensors for grouped reduction operations. However, it only examined cached input tensors that already had inner broadcast dimensions (e.g., `input [Iter,1] --> tv1 [Iter,1]`, `T4 --> T73` in the figure below). This approach missed tensors that initially lack broadcast dimensions but are later consumed by broadcast operations (e.g., `input [Iter] --> tv1 [Iter] --broadcast()--> tv2 [Iter,1]`). This limitation specifically affects layer norm backward operations generated by Thunder, as illustrated by tensors `T70` and `T69` in the figure.



Red colored reduction path: `T70 --> T85`
Blue colored post-reduction path: `T70 --> T44`
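A minimal sketch of the previously missed pattern, written with the nvFuser C++ API (tensor names chosen for illustration, not taken from the actual fusion):

```cpp
// tv1 itself is [Iter] with no inner broadcast dimension, so the old logic
// skipped it; only the downstream broadcast() produces the [Iter, 1]
// consumer, analogous to T70/T69 in the figure above.
TensorView* tv0 = makeContigTensor(1);            // input [Iter]
fusion.addInput(tv0);
TensorView* tv1 = set(tv0);                       // cached copy, still [Iter]
TensorView* tv2 = broadcast(tv1, {false, true});  // [Iter, 1]
```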
Fix
Modified the loop logic to iterate through all cached tensors, only skipping those that are either TMA-loaded or already vectorized.
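A hand-written approximation of the new loop structure (the TMA check uses a hypothetical helper name, not the exact function in the scheduler):

```cpp
for (auto cached_tv : cached_inputs) {
  // Skip tvs loaded through TMA; they are scheduled separately.
  if (isTmaLoaded(cached_tv)) {  // hypothetical helper
    continue;
  }
  // Skip tvs that are already vectorized in general vectorization
  // analysis and propagation.
  if (cached_tv->axis(-1)->getParallelType() == ParallelType::Vectorize) {
    continue;
  }
  // ... schedule the remaining cached tensor as a persistent buffer ...
}
```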
Additional Changes
Slightly adjusted heuristics to enhance performance for layer norm backward operations when ping-pong buffering is not utilized.
Tests
The fusion definition generated by Thunder is converted to a C++ test.
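For reference, a rough skeleton of how such a converted test is typically structured in this repo; the actual test uses the full Thunder-generated graph, and its `testValidate` call is the one commented out in the TODO above:

```cpp
TEST_F(NVFuserTest, ThunderLnBwdSketch) {
  auto fusion_ptr = std::make_unique<Fusion>();
  Fusion& fusion = *fusion_ptr;
  FusionGuard fg(&fusion);

  // Simplified stand-in for the Thunder-generated definition: a 1D tensor
  // that is broadcast and multiplied into a 2D activation, then reduced.
  TensorView* tv0 = makeContigTensor(2, DataType::BFloat16);
  TensorView* tv1 = makeContigTensor(1);
  fusion.addInput(tv0);
  fusion.addInput(tv1);
  auto tv2 = castOp(DataType::Float, tv0);
  auto tv3 = broadcast(tv1, {false, true});
  auto tv4 = mul(tv2, tv3);
  auto tv5 = sum(tv4, {1});
  fusion.addOutput(tv5);

  auto options = at::TensorOptions().dtype(at::kBFloat16).device(at::kCUDA, 0);
  auto t0 = at::randn({2048, 8192}, options);
  auto t1 = at::randn({2048}, options.dtype(at::kFloat));

  FusionExecutorCache executor_cache(std::move(fusion_ptr));
  auto cg_outputs = executor_cache.runFusionWithInputs({t0, t1});
  testValidate(&fusion, cg_outputs, {t0, t1}, __LINE__, __FILE__);
}
```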
Performance
(1) Compare layer norm bwd with multi-wave approach


(2) Compare RMS norm bwd with the current main branch using the user-enforced warp specialized approach.