Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix indexing of take_along_axis #4128

Merged
merged 23 commits into from
Mar 26, 2025
Merged

Fix indexing of take_along_axis #4128

merged 23 commits into from
Mar 26, 2025

Conversation

naoyam
Copy link
Collaborator

@naoyam naoyam commented Mar 21, 2025

Previously, in TensorIndexer, indexing indirectly accessed tensors in ops like take_along_axis wasn't working when the tensor is not entirely allocated (in global memory). It was due to the allocation domain not properly set. Specifically, since the indirectly accessed logical ID is indexed, it must be included in its allocation domain. The legacy indexer somehow works, but the new indexer resulted in a failure with the index path traversal. This PR fixes the allocation domain.

Note that gather is not yet supported in TensorIndexer when it is not exact.

@naoyam naoyam changed the base branch from main to alloc_refactor March 21, 2025 22:02
@@ -567,8 +569,8 @@ TEST_F(GatherTest, TakeAlongAxisIntermediateTensorPointwise2) {
Fusion& fusion = *fusion_ptr.get();
FusionGuard fg(&fusion);

EnableOptionsGuard opt_guard;
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was necessary before but not anymore since the NVFuserTest fixture itself has one.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All the take_along_axis tests are enabled. Those using gather still needs more work.

csrc/bfs.h Outdated
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just added some shortcuts

Copy link

github-actions bot commented Mar 21, 2025

Review updated until commit 1466576

Description

  • Added patchAllocationOfIndexedProducerTensor to handle allocation of indexed producer tensors.

  • Included bfs.h for BFS utilities.

  • Updated tests to enable IdModel option for take_along_axis tests.


Changes walkthrough 📝

Relevant files
Enhancement
allocation.cpp
Add indexed producer tensor allocation handling                   

csrc/device_lower/pass/allocation.cpp

  • Included bfs.h for BFS utilities.
  • Added patchAllocationOfIndexedProducerTensor method.
  • Updated allocation domain setup to handle indexed producer tensors.
  • +116/-9 
    bfs.h
    Add BFS template overloads                                                             

    csrc/bfs.h

  • Added template overloads for getInputsOfExpr, getOutputsOfExpr,
    getInputsOfExprPath, and getOutputsOfExprPath.
  • +52/-1   
    Tests
    test_gather.cpp
    Enable IdModel in take_along_axis tests                                   

    tests/cpp/test_gather.cpp

    • Enabled IdModel option in multiple take_along_axis tests.
    +16/-10 

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review

    Complexity

    The new function patchAllocationOfIndexedProducerTensor is quite complex and involves multiple steps. It would be beneficial to add more comments explaining the logic and purpose of each step.

    // indexed logical IDs need to be entirely allocated.
    std::optional<std::vector<IterDomain*>> patchAllocationOfIndexedProducerTensor(
        const TensorView* tv,
        const std::vector<IterDomain*>& allocation_ids) const {
      VectorOfUniqueEntries<Val*> indexed_logical_ids;
      for (auto use_expr : tv->uses()) {
        auto indexed_id = ir_utils::getIndexedProducerID(use_expr);
        if (indexed_id == nullptr ||
            std::find(
                tv->getLogicalDomain().begin(),
                tv->getLogicalDomain().end(),
                indexed_id) == tv->getLogicalDomain().end()) {
          continue;
        }
    
        // This indexed_id is indirectly accessed and needs to be
        // allocated entirely.
    
        // If it's already in the allocation ID set, nothing further
        // needs to be done
        if (std::find(allocation_ids.begin(), allocation_ids.end(), indexed_id) !=
            allocation_ids.end()) {
          continue;
        }
    
        indexed_logical_ids.pushBack(indexed_id);
      }
    
      if (indexed_logical_ids.empty()) {
        return std::nullopt;
      }
    
      // indexed_logical_ids is not in the current allocation ID
      // list. Find the allocation IDs that are equivalent to the
      // indexed IDs. The indexed IDs should be reachable from the
      // allocation IDs, and those allocation IDs used in the traversal
      // path should be the ones that should be replaced with the
      // indexed IDs.
    
      // In order to retain the original ordering of allocation IDs,
      // each indexed logical ID is examined one by one. Specifically,
      // for each of them, we find the corresponding IDs in the current
      // allocation ID vector and replace them with the indexed logical
      // ID.
      auto patched_allocation_ids = allocation_ids;
      for (auto indexed_logical_id : indexed_logical_ids) {
        auto [path, all_visited] = getExprsBetween<IRBFS>(
            {patched_allocation_ids.begin(), patched_allocation_ids.end()},
            {indexed_logical_id},
            /*require_all_to_visited=*/false);
        NVF_ERROR(
            all_visited,
            "Failed to infer valid allocation IDs. Indexed logical IDs need to be entirely allocated but not found in the inferred allocation ID set. Indexed logical ID: ",
            indexed_logical_id->toString(),
    
            ". Allocation IDs: ",
            toDelimitedString(patched_allocation_ids));
    
        auto dependent_allocation_ids = getInputsOfExprPath<IRBFS>(path);
    
        // Insert indexed_logical_id at the innermost position of
        // dependent_allocation_ids.
        int num_dependent_allocation_ids = 0;
        std::vector<IterDomain*> pathched_allocation_ids_next;
        for (auto id : allocation_ids) {
          if (std::find(
                  dependent_allocation_ids.begin(),
                  dependent_allocation_ids.end(),
                  id) != dependent_allocation_ids.end()) {
            ++num_dependent_allocation_ids;
            if (num_dependent_allocation_ids ==
                std::ssize(dependent_allocation_ids)) {
              pathched_allocation_ids_next.push_back(
                  indexed_logical_id->as<IterDomain>());
            }
          } else {
            pathched_allocation_ids_next.push_back(id);
          }
        }
    
        std::swap(patched_allocation_ids, pathched_allocation_ids_next);
      }
    
      return patched_allocation_ids;
    }
    Error Handling

    The function patchAllocationOfIndexedProducerTensor uses NVF_ERROR to handle cases where the indexed logical IDs are not found in the allocation ID set. Consider adding more specific error messages or handling these cases more gracefully.

    all_visited,
    "Failed to infer valid allocation IDs. Indexed logical IDs need to be entirely allocated but not found in the inferred allocation ID set. Indexed logical ID: ",
    indexed_logical_id->toString(),
    
    ". Allocation IDs: ",
    toDelimitedString(patched_allocation_ids));
    Template Overloading

    The addition of template overloads for getInputsOfExpr, getOutputsOfExpr, and getInputsOfExprPath in bfs.h may lead to ambiguity or unexpected behavior. Ensure that these overloads do not conflict with existing functions and that they are necessary.

    template <typename BFSType, typename... AdditionalArgs>
    std::vector<typename BFSType::ValType> getInputsOfExpr(
        const typename BFSType::ExprType& expr,
        Direction dir,
        const AdditionalArgs&... additional_args) {
      return getInputsOfExpr(
          expr,
          dir,
          typename BFSType::InputsType(additional_args...),
          typename BFSType::OutputsType(additional_args...));
    }
    
    template <typename ExprT, typename InputsT, typename OutputsT>
    std::vector<typename GetValType<ExprT>::type> getOutputsOfExpr(
        const ExprT& expr,
        Direction dir,
        InputsT inputs,
        OutputsT outputs) {
      return getInputsOfExpr(expr, reverse(dir), inputs, outputs);
    }
    
    template <typename BFSType, typename... AdditionalArgs>
    std::vector<typename BFSType::ValType> getOutputsOfExpr(
        const typename BFSType::ExprType& expr,
        Direction dir,
        const AdditionalArgs&... additional_args) {
      return getInputsOfExpr<BFSType>(expr, reverse(dir), additional_args...);
    }

    Base automatically changed from alloc_refactor to main March 21, 2025 22:40
    @naoyam naoyam marked this pull request as ready for review March 21, 2025 22:47
    @@ -258,9 +259,6 @@ class AllocationDomainSetup : private kir::IrVisitor {
    for (const auto i : c10::irange(tv->nDims())) {
    auto loop_id = tv->getLoopDomain().at(i);
    auto pt = loop_id->getParallelType();
    if (!mayRequireAllocation(tv, loop_id)) {
    Copy link
    Collaborator Author

    Choose a reason for hiding this comment

    The reason will be displayed to describe this comment to others. Learn more.

    This would drop broadcast IDs as they wouldn't contribute to the final allocation of the tensor, but they would be still required for the analysis of patchAllocationOfIndexedProducerTensor as it uses IRBFS. Note that such broadcast IDs are eventually removed anyway at the end of this function.

    @naoyam
    Copy link
    Collaborator Author

    naoyam commented Mar 21, 2025

    !test

    @naoyam
    Copy link
    Collaborator Author

    naoyam commented Mar 22, 2025

    !test

    @naoyam
    Copy link
    Collaborator Author

    naoyam commented Mar 23, 2025

    !test

    @naoyam
    Copy link
    Collaborator Author

    naoyam commented Mar 23, 2025

    !test --diff

    @naoyam naoyam requested a review from zasdfgbnm March 24, 2025 23:07
    @naoyam
    Copy link
    Collaborator Author

    naoyam commented Mar 24, 2025

    The codediff changes are expected as the new indexer still lacks the support of magic zero.

    Copy link
    Collaborator

    @zasdfgbnm zasdfgbnm left a comment

    Choose a reason for hiding this comment

    The reason will be displayed to describe this comment to others. Learn more.

    LGTM, left some nitpicks

    Comment on lines +773 to +774
    if (std::find(allocation_ids.begin(), allocation_ids.end(), indexed_id) !=
    allocation_ids.end()) {
    Copy link
    Collaborator

    Choose a reason for hiding this comment

    The reason will be displayed to describe this comment to others. Learn more.

    nit:

    !std::ranges::contains(allocation_ids, indexed_id)

    @@ -729,6 +749,93 @@ class AllocationDomainSetup : private kir::IrVisitor {
    return patched_allocation_domains;
    }

    // If a producer tensor is accessed through supplied indices, the
    // indexed logical IDs need to be entirely allocated.
    std::optional<std::vector<IterDomain*>> patchAllocationOfIndexedProducerTensor(
    Copy link
    Collaborator

    Choose a reason for hiding this comment

    The reason will be displayed to describe this comment to others. Learn more.

    nit: Do we need std::optional? Could we just return an empty vector?

    Copy link
    Collaborator Author

    Choose a reason for hiding this comment

    The reason will be displayed to describe this comment to others. Learn more.

    I thought it might be ambiguous because an empty vector could mean the input allocation domain shouldn't be changed or the allocation domain itself could be empty, so using std::optional would make it more explicit.

    Comment on lines +285 to +287
    NVF_ERROR(std::all_of(contiguity.begin(), contiguity.end(), [](auto b) {
    return b.has_value() && b.value();
    }));
    Copy link
    Collaborator

    Choose a reason for hiding this comment

    The reason will be displayed to describe this comment to others. Learn more.

    nit

    NVF_ERROR(std::ranges::all_of(contiguity, [](auto b) {
      return b && *b;
    }));

    Copy link
    Collaborator Author

    Choose a reason for hiding this comment

    The reason will be displayed to describe this comment to others. Learn more.

    This is just my personal preference, but I slightly prefer more explicit style.

    Comment on lines +818 to +821
    if (std::find(
    dependent_allocation_ids.begin(),
    dependent_allocation_ids.end(),
    id) != dependent_allocation_ids.end()) {
    Copy link
    Collaborator

    Choose a reason for hiding this comment

    The reason will be displayed to describe this comment to others. Learn more.

    nit:

    !std::ranges::contains(dependent_allocation_ids, id)

    @naoyam naoyam merged commit 1c67aec into main Mar 26, 2025
    58 of 60 checks passed
    @naoyam naoyam deleted the gather_alloc branch March 26, 2025 14:57
    Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
    Labels
    None yet
    Projects
    None yet
    Development

    Successfully merging this pull request may close these issues.

    2 participants