-
Notifications
You must be signed in to change notification settings - Fork 55
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix indexing of take_along_axis #4128
Conversation
@@ -567,8 +569,8 @@ TEST_F(GatherTest, TakeAlongAxisIntermediateTensorPointwise2) { | |||
Fusion& fusion = *fusion_ptr.get(); | |||
FusionGuard fg(&fusion); | |||
|
|||
EnableOptionsGuard opt_guard; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was necessary before but not anymore since the NVFuserTest
fixture itself has one.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
All the take_along_axis
tests are enabled. Those using gather
still needs more work.
csrc/bfs.h
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just added some shortcuts
Review updated until commit 1466576 Description
Changes walkthrough 📝
PR Reviewer Guide 🔍Here are some key observations to aid the review process:
|
@@ -258,9 +259,6 @@ class AllocationDomainSetup : private kir::IrVisitor { | |||
for (const auto i : c10::irange(tv->nDims())) { | |||
auto loop_id = tv->getLoopDomain().at(i); | |||
auto pt = loop_id->getParallelType(); | |||
if (!mayRequireAllocation(tv, loop_id)) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This would drop broadcast IDs as they wouldn't contribute to the final allocation of the tensor, but they would be still required for the analysis of patchAllocationOfIndexedProducerTensor
as it uses IRBFS
. Note that such broadcast IDs are eventually removed anyway at the end of this function.
!test |
!test |
!test |
!test --diff |
The codediff changes are expected as the new indexer still lacks the support of magic zero. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, left some nitpicks
if (std::find(allocation_ids.begin(), allocation_ids.end(), indexed_id) != | ||
allocation_ids.end()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit:
!std::ranges::contains(allocation_ids, indexed_id)
@@ -729,6 +749,93 @@ class AllocationDomainSetup : private kir::IrVisitor { | |||
return patched_allocation_domains; | |||
} | |||
|
|||
// If a producer tensor is accessed through supplied indices, the | |||
// indexed logical IDs need to be entirely allocated. | |||
std::optional<std::vector<IterDomain*>> patchAllocationOfIndexedProducerTensor( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: Do we need std::optional
? Could we just return an empty vector?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I thought it might be ambiguous because an empty vector could mean the input allocation domain shouldn't be changed or the allocation domain itself could be empty, so using std::optional
would make it more explicit.
NVF_ERROR(std::all_of(contiguity.begin(), contiguity.end(), [](auto b) { | ||
return b.has_value() && b.value(); | ||
})); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit
NVF_ERROR(std::ranges::all_of(contiguity, [](auto b) {
return b && *b;
}));
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is just my personal preference, but I slightly prefer more explicit style.
if (std::find( | ||
dependent_allocation_ids.begin(), | ||
dependent_allocation_ids.end(), | ||
id) != dependent_allocation_ids.end()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit:
!std::ranges::contains(dependent_allocation_ids, id)
Previously, in TensorIndexer, indexing indirectly accessed tensors in ops like take_along_axis wasn't working when the tensor is not entirely allocated (in global memory). It was due to the allocation domain not properly set. Specifically, since the indirectly accessed logical ID is indexed, it must be included in its allocation domain. The legacy indexer somehow works, but the new indexer resulted in a failure with the index path traversal. This PR fixes the allocation domain.
Note that
gather
is not yet supported in TensorIndexer when it is not exact.