Add lowering for insert_slice
-like scatter
ops (KV-cache)
#2771
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Coming from #1758, this PR adds a lowering for a narrow case of
stablehlo.scatter
, namely those that are equivalent to atensor.insert_slice
.This is a common case, since it's how KV-cache updates are modelled when exporting from PyTorch.
The code is adapted from an implementation in the catalyst compiler by @erick-xanadu, expanded to cover the cases I saw coming out of PyTorch.
The conversion produces
tensor.insert_slice
ops, rather than linalg. This may or may not be acceptable, but I'm putting the PR up first to get thoughts.I don't believe there's an exact equivalent to
insert_slice
in linalg, but I think it could be achieved with alinalg.generic
. However, since the StableHLO conversion inserts a bunch oftensor
ops anyway, I don't think lowering totensor.insert_slice
directly is against the spirit of what already exists.All other cases of scatter will be left as-is, but given it's quite a complex op, this pattern will provide near-term utility while a more general purpose lowering is cooked up.
Note to reviewers, this is my first contribution to the project, so there may some workflow things I've missed.