indexAccumulate python api #4066
base: jjsjann123/index_put
Conversation
Review updated until commit 6b96692
Marking this as draft to avoid accidental merge.
Do you need to define `void handle(IndexAccumulateOp* iaop)` in `csrc/python_frontend/translation.cpp` for the python clone and segmentation features?
Otherwise, the PR looks good to me.
py::arg("acc"),
py::arg("index"),
py::arg("value"),
py::return_value_policy::reference);
I'm trying to improve python user experience by adding a docstring to new functions.
Docstring generated by Gemini.
m.def("index_accumulate", &indexAccumulate,
    py::arg("acc_tv"), py::arg("index_tv"), py::arg("value_tv"),
    R"(
Accumulates values into a tensor at specified indices.

This function performs a restricted version of `torch.index_put(..., accumulate=true)`.
It adds the values from `value_tv` to the elements of `acc_tv` at the indices
specified by `index_tv`.

Args:
    acc_tv: The tensor to accumulate into (modified in place).
    index_tv: The tensor containing the indices.
    value_tv: The tensor containing the values to accumulate.

Returns:
    A pointer to the modified `acc_tv` tensor.

Note:
    This is a restricted version and may not support all features of the
    full `torch.index_put(..., accumulate=true)` function.
)");
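A minimal pure-Python sketch of the accumulate semantics the docstring describes, assuming 1-D inputs represented as plain lists. The function below is only an illustration of the behavior, not the actual nvFuser implementation:

```python
def index_accumulate(acc, index, value):
    # For each position i, add value[i] into acc at position index[i].
    # Duplicate indices accumulate rather than overwrite.
    for i, v in zip(index, value):
        acc[i] += v
    return acc
```

Note that with duplicate indices (e.g. index 2 appearing twice), the contributions sum, which is the `accumulate=True` behavior.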
Hahaha, thanks for the draft~~~ will add it in.
This PR supports embedding backward, which requires `torch.index_put_(..., accumulate=True)`.
Stacked PRs:
What this PR does:
Tensor fd.ops.index_accumulate(Tensor acc, Tensor index, Tensor value)
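As a rough illustration of the semantics this op targets, the same accumulation can be reproduced with NumPy's unbuffered in-place add, `np.add.at`; the tensor values here are made up for the example:

```python
import numpy as np

acc = np.zeros(5, dtype=np.float32)
index = np.array([0, 2, 2, 4])
value = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)

# acc[index[i]] += value[i] for every i; the two entries at
# index 2 accumulate (2.0 + 3.0 = 5.0) instead of overwriting.
np.add.at(acc, index, value)
```

`np.add.at` is the unbuffered variant of fancy-index assignment: plain `acc[index] += value` would apply duplicate indices only once, which is exactly the difference `accumulate=True` controls.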