indexAccumulate python api #4066

Draft: wants to merge 7 commits into base branch jjsjann123/index_put
Conversation

@jjsjann123 (Collaborator) commented Mar 12, 2025:

This PR supports embedding backward, which requires torch.index_put_(..., accumulate=True).

Stacked PRs:

What this PR does:

  • Added Python API Tensor fd.ops.index_accumulate(Tensor acc, Tensor index, Tensor value) (a usage sketch follows below)
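
For illustration, here is a minimal usage sketch in the nvFuser Python frontend. The FusionDefinition boilerplate, symbolic shapes, and concrete sizes (taken from the opinfo generator further down) are assumptions for the example, not part of this PR:

    import torch
    from nvfuser import FusionDefinition, DataType

    with FusionDefinition() as fd:
        # acc:   [vocab, hidden] accumulator (e.g. an embedding weight gradient)
        # index: [seq] integer indices into the vocab dimension
        # value: [seq, hidden] rows to scatter-add into acc
        acc = fd.define_tensor(shape=[-1, -1], dtype=DataType.Float)
        index = fd.define_tensor(shape=[-1], dtype=DataType.Int)
        value = fd.define_tensor(shape=[-1, -1], dtype=DataType.Float)
        out = fd.ops.index_accumulate(acc, index, value)
        fd.add_output(out)

    vocab, hidden, seq = 1024, 12, 300
    result, = fd.execute([
        torch.zeros(vocab, hidden, device="cuda"),
        torch.randint(0, vocab, (seq,), device="cuda"),
        torch.randn(seq, hidden, device="cuda"),
    ])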


github-actions bot commented Mar 12, 2025

Review updated until commit 6b96692

Description

  • Added index_accumulate Python API

  • Included opinfo test for index_accumulate

  • Updated fusion record handling for index_accumulate

  • Formatted code with clang-format and black


Changes walkthrough 📝

Relevant files

Enhancement

python_bindings.cpp (csrc/python_frontend/python_bindings.cpp): Add index_accumulate Python API
  • Added index_accumulate function to Python bindings
  +24/-0

fusion_record.cpp (csrc/serde/fusion_record.cpp): Add deserialization for IndexAccumulateOpRecord
  • Added deserialization for IndexAccumulateOpRecord
  +7/-0

fusion_record.h (csrc/python_frontend/fusion_record.h): Add IndexAccumulateOpRecord
  • Added IndexAccumulateOpRecord struct
  +22/-0

Tests

opinfo_input_generators.py (tests/python/opinfo_input_generators.py): Add index_accumulate_generator
  • Added index_accumulate_generator function
  +19/-0

opinfos.py (tests/python/opinfos.py): Add index_accumulate opinfo
  • Added index_accumulate_generator to opinfo list
  • Added index_accumulate_ref function
  • Created index_accumulate_opinfo and appended to shape_ops
  +26/-0

Configuration changes

fusion_cache.fbs (csrc/serde/fusion_cache.fbs): Add IndexAccumulateOp to RecordType
  • Added IndexAccumulateOp to RecordType enum
  +1/-0

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Performance Goal

Ensure a clear performance goal is set and that feedback was sought early regarding the expected performance improvements of the index_accumulate function.

    "index_accumulate",
    [](FusionDefinition::Operators& self,
       Tensor acc,
       Tensor index,
       Tensor value) -> Tensor {
      FUSER_PERF_SCOPE("Operators.index_accumulate");
      NVF_CHECK(
          self.validUse(), "Attempting to add to a completed definition!");
      FusionDefinition* fd = self.fusion_definition;
      Tensor output = fd->defineTensor(acc.dims);
      fd->defineRecord(new IndexAccumulateOpRecord(
          {
              fd->recordingState(acc()),
              fd->recordingState(index()),
              fd->recordingState(value()),
          },
          {fd->recordingState(output())}));
      return output;
    },
    py::arg("acc"),
    py::arg("index"),
    py::arg("value"),
    py::return_value_policy::reference);
Test Coverage

Verify that the test cases in index_accumulate_generator cover a wide range of scenarios and edge cases to ensure the correctness and robustness of the index_accumulate function.

    def index_accumulate_generator(
        op: OpInfo, dtype: torch.dtype, requires_grad: bool = False, **kwargs
    ):
        make_arg = partial(
            make_tensor, device="cuda", dtype=dtype, requires_grad=requires_grad
        )
        make_index = partial(make_tensor, device="cuda", requires_grad=False)
    
        # vocab_size, hidden_size, seq_size
        cases = ((1024, 12, 300),)
    
        for vocab, hidden, seq in cases:
            for index_dtype in [torch.int, torch.long]:
                acc = make_arg((vocab, hidden))
                index = make_index((seq,), low=0, high=vocab, dtype=index_dtype)
                value = make_arg((seq, hidden))
                yield SampleInput(acc, index, value)
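
The walkthrough above mentions an index_accumulate_ref function in opinfos.py. A plausible sketch of such a reference, assuming it simply defers to the out-of-place torch.index_put with accumulate=True (the semantics the PR description targets):

    import torch

    def index_accumulate_ref(acc, index, value):
        # Out-of-place reference: scatter-add rows of `value` into `acc` at `index`.
        return torch.index_put(acc, (index,), value, accumulate=True)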
Error Handling

Ensure that the IndexAccumulateOpRecord operator handles potential errors gracefully, such as mismatched tensor dimensions or unsupported data types.

    IndexAccumulateOpRecord(std::vector<State> args, std::vector<State> outputs)
        : RecordFunctor(
              std::move(args),
              std::move(outputs),
              "ops.index_accumulate",
              serde::RecordType::IndexAccumulateOp) {}
    ~IndexAccumulateOpRecord() override = default;
    RecordFunctor* clone() final {
      return new IndexAccumulateOpRecord(*this);
    }
    
    void operator()(FusionState& fd) final {
      auto acc = fd.getFusionState(args_.at(0).index)->as<TensorView>();
      auto index = fd.getFusionState(args_.at(1).index)->as<TensorView>();
      auto value = fd.getFusionState(args_.at(2).index)->as<TensorView>();
    
      auto output = indexAccumulate(acc, index, value);
      fd.setFusionState(outputs_.at(0).index, output);
    }
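
On the error-handling question: per index_accumulate_generator above, the implied shape contract is acc [vocab, hidden], index [seq], value [seq, hidden]. A hypothetical caller-side illustration of that contract (these checks are an assumption inferred from the generator, not validation the PR adds):

    import torch

    def check_index_accumulate_inputs(acc, index, value):
        # Shape/dtype contract inferred from the opinfo generator above.
        assert acc.ndim == 2 and index.ndim == 1 and value.ndim == 2
        assert index.dtype in (torch.int32, torch.int64)
        assert value.shape == (index.shape[0], acc.shape[1])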

@jjsjann123 mentioned this pull request Mar 14, 2025
@jjsjann123 marked this pull request as ready for review March 14, 2025 01:18
@jjsjann123 requested review from rdspring1 and protonu March 14, 2025 01:19
@jjsjann123 marked this pull request as draft March 14, 2025 01:19
@jjsjann123 (Collaborator, Author) commented:

Marking this as draft to avoid accidental merge. But this PR is good for review as-is.

@rdspring1 (Collaborator) left a comment:

Do you need to define void handle(IndexAccumulateOp* iaop) in csrc/python_frontend/translation.cpp for the python clone and segmentation features?

Otherwise, the PR looks good to me.

    py::arg("acc"),
    py::arg("index"),
    py::arg("value"),
    py::return_value_policy::reference);
A collaborator commented on the lines above:

I'm trying to improve the Python user experience by adding a docstring to new functions.

Docstring generated by Gemini.

        m.def("index_accumulate", &indexAccumulate,
              py::arg("acc_tv"), py::arg("index_tv"), py::arg("value_tv"),
              R"(
            Accumulates values into a tensor at specified indices.
    
            This function performs a restricted version of `torch.index_put(..., accumulate=true)`.
            It adds the values from `value_tv` to the elements of `acc_tv` at the indices
            specified by `index_tv`.
    
            acc_tv: The tensor to accumulate into (in-place modification).
            index_tv: The tensor containing the indices.
            value_tv: The tensor containing the values to accumulate.
    
            Returns:
                A pointer to the modified `acc_tv` tensor.
    
            Note:
                This is a restricted version and may not support all features of the
                full `torch.index_put(..., accumulate=true)` function.
        )");

@jjsjann123 (Collaborator, Author) replied:

Hahaha, thanks for the draft~~~ Will add it in.
