Skip to content

Add lowering for insert_slice-like scatter ops (KV-cache) #2771

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 111 additions & 0 deletions stablehlo/conversions/linalg/tests/scatter.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
// RUN: stablehlo-opt %s --stablehlo-legalize-to-linalg --split-input-file --canonicalize | FileCheck %s

// Positive case: the update computation just returns the update value and the
// indices tensor (tensor<1x1xi64>) holds a single index that addresses operand
// dimension 2. The update tensor has the same rank as the operand
// (1x32x1x128 into 1x32x32x128), so the scatter is expected to become a
// tensor.insert_slice whose offset along dimension 2 is the extracted index.
func.func @matching_update_tensor(%arg0: tensor<1x32x32x128xf32>, %arg1: tensor<1x32x1x128xf32>, %arg2: tensor<1x1xi64>) -> tensor<1x32x32x128xf32> {
// CHECK-NOT: stablehlo.scatter
// CHECK: %[[ZERO:.*]] = arith.constant 0 : index
// CHECK: %[[EXT:.*]] = tensor.extract %arg2[%[[ZERO]], %[[ZERO]]] : tensor<1x1xi64>
// CHECK: %[[IDX:.*]] = arith.index_cast %[[EXT]] : i64 to index
// CHECK: tensor.insert_slice %arg1 into %arg0[0, 0, %[[IDX]], 0] [1, 32, 1, 128] [1, 1, 1, 1] : tensor<1x32x1x128xf32> into tensor<1x32x32x128xf32>
%0 = "stablehlo.scatter"(%arg0, %arg2, %arg1) <{
indices_are_sorted = false,
scatter_dimension_numbers = #stablehlo.scatter<
update_window_dims = [0, 1, 3],
inserted_window_dims = [2],
scatter_dims_to_operand_dims = [2],
index_vector_dim = 1>,
unique_indices = false}> ({
// Assignment-style combiner: the incoming update (%arg4) replaces the
// current operand element (%arg3).
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
stablehlo.return %arg4 : tensor<f32>
}) : (tensor<1x32x32x128xf32>, tensor<1x1xi64>, tensor<1x32x1x128xf32>) -> tensor<1x32x32x128xf32>
return %0 : tensor<1x32x32x128xf32>


}

// -----

// Positive case with a rank-reduced update: a 7x5 update is written into a
// 9x7x5 operand at the position given by the single index in tensor<1xi32>.
// The expected tensor.insert_slice uses sizes [1, 7, 5], so its source has one
// dimension fewer than the slice it fills.
func.func @smaller_update_tensor() -> tensor<9x7x5xf64> {
// CHECK-DAG: %[[scatter_indices:.*]] = tensor.empty() : tensor<1xi32>
// CHECK-DAG: %[[inputs:.*]] = tensor.empty() : tensor<9x[[dim1:.*]]x[[dim0:.*]]xf64>
// CHECK-DAG: %[[updates:.*]] = tensor.empty() : tensor<[[dim1]]x[[dim0]]xf64>
// CHECK-DAG: %[[zero:.*]] = arith.constant 0 : index
%scatter_indices = tensor.empty() : tensor<1xi32>
%inputs = tensor.empty() : tensor<9x7x5xf64>
%updates = tensor.empty() : tensor<7x5xf64>

// CHECK-DAG: %[[ext:.*]] = tensor.extract %[[scatter_indices]][%[[zero]]] : tensor<1xi32>
// CHECK-DAG: %[[idx:.*]] = arith.index_cast %[[ext]] : i32 to index
// CHECK-DAG: %[[inserted_slice:.*]] = tensor.insert_slice %[[updates]] into %[[inputs]][%[[idx]], 0, 0] [1, [[dim1]], [[dim0]]] [1, 1, 1] : tensor<7x5xf64> into tensor<9x7x5xf64>

// No index_vector_dim is spelled out here; the attribute's default is used.
%3 = "stablehlo.scatter"(%inputs, %scatter_indices, %updates) <{
indices_are_sorted = true,
scatter_dimension_numbers = #stablehlo.scatter<
update_window_dims = [0, 1],
inserted_window_dims = [0],
scatter_dims_to_operand_dims = [0]>,
unique_indices = true}> ({
// Assignment-style combiner: returns the update element unchanged.
^bb0(%arg0: tensor<f64>, %arg1: tensor<f64>):
stablehlo.return %arg1 : tensor<f64>
}) : (tensor<9x7x5xf64>, tensor<1xi32>, tensor<7x5xf64>) -> tensor<9x7x5xf64>
return %3 : tensor<9x7x5xf64>
}

// -----

// Negative case: this scatter uses batching dimensions, several scatter
// indices and an additive combiner, so it is not insert_slice-like and is
// expected to be left in place.
func.func @non_matching_scatter(%arg0: tensor<2x3x4x2xi64>, %arg1: tensor<2x2x3x2xi64>, %arg2: tensor<2x2x3x2x2xi64>) -> tensor<2x3x4x2xi64> {
// CHECK: stablehlo.scatter
%0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [3, 4], inserted_window_dims = [1], input_batching_dims = [0], scatter_indices_batching_dims = [1], scatter_dims_to_operand_dims = [2, 1], index_vector_dim = 3>, unique_indices = false}> ({
// Accumulating combiner (current + update), i.e. not a plain assignment.
^bb0(%arg3: tensor<i64>, %arg4: tensor<i64>):
%1 = stablehlo.add %arg3, %arg4 : tensor<i64>
stablehlo.return %1 : tensor<i64>
}) : (tensor<2x3x4x2xi64>, tensor<2x2x3x2xi64>, tensor<2x2x3x2x2xi64>) -> tensor<2x3x4x2xi64>
return %0 : tensor<2x3x4x2xi64>
}

// -----

// Negative case written in the generic attribute-dictionary form: batching
// dimensions are present and the combiner adds instead of assigning, so the
// scatter is expected to survive the pass unchanged.
func.func @scatter_with_batching_dims(%input_tensor: tensor<5x200x100x300xf32>,
%scatter_indices: tensor<5x10x2xi32>, %updates: tensor<5x10x300xf32>) ->
tensor<5x200x100x300xf32> {
// CHECK: stablehlo.scatter
%0 = "stablehlo.scatter" (%input_tensor, %scatter_indices, %updates) ({
// Accumulating combiner (current + update), i.e. not a plain assignment.
^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
%add = stablehlo.add %lhs, %rhs : tensor<f32>
"stablehlo.return"(%add) : (tensor<f32>) -> ()
}) {
scatter_dimension_numbers = #stablehlo.scatter<
update_window_dims = [2],
inserted_window_dims = [1, 2],
input_batching_dims = [0],
scatter_indices_batching_dims = [0],
scatter_dims_to_operand_dims = [1, 2],
index_vector_dim = 2
>,
indices_are_sorted = true,
unique_indices = true
} : (tensor<5x200x100x300xf32>, tensor<5x10x2xi32>, tensor<5x10x300xf32>) ->
tensor<5x200x100x300xf32>
func.return %0 : tensor<5x200x100x300xf32>
}

// -----

// Negative case with dynamically shaped operands: the shapes cannot be
// verified statically and the combiner adds instead of assigning, so the
// scatter must be left in place (and the pass must not crash on it).
// CHECK-LABEL: func.func @valid_scatter_dimensions_with_dynamic_index_vector_dim
func.func @valid_scatter_dimensions_with_dynamic_index_vector_dim(
    %input_tensor: tensor<?x?x?xf32>, %scatter_indices: tensor<10x?xi32>,
    %updates: tensor<?x?xf32>) -> tensor<?x?x?xf32> {
  // CHECK: stablehlo.scatter
  %0 = "stablehlo.scatter" (%input_tensor, %scatter_indices, %updates) ({
    // Accumulating combiner (current + update), i.e. not a plain assignment.
    ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
      %add = stablehlo.add %lhs, %rhs : tensor<f32>
      "stablehlo.return"(%add) : (tensor<f32>) -> ()
  }) {
    scatter_dimension_numbers = #stablehlo.scatter<
      update_window_dims = [1],
      inserted_window_dims = [0, 1],
      scatter_dims_to_operand_dims = [0, 1, 2],
      index_vector_dim = 1
    >,
    indices_are_sorted = true,
    unique_indices = true
  } : (tensor<?x?x?xf32>, tensor<10x?xi32>, tensor<?x?xf32>) -> tensor<?x?x?xf32>
  func.return %0 : tensor<?x?x?xf32>
}
1 change: 1 addition & 0 deletions stablehlo/conversions/linalg/transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ add_mlir_library(StablehloLinalgTransforms
StablehloToLinalgPointwise.cpp
StablehloToLinalgRandom.cpp
StablehloToLinalgReduce.cpp
StablehloToLinalgScatter.cpp
TypeConversion.cpp

DEPENDS
Expand Down
6 changes: 6 additions & 0 deletions stablehlo/conversions/linalg/transforms/Rewriters.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@ void populateStablehloReductionToLinalgConversionPatterns(
MLIRContext *context, TypeConverter &typeConverter,
RewritePatternSet *patterns, bool enablePrimitiveOps);

/// Populates `patterns` with the conversions for StableHLO scatter ops.
///
/// The parameter list mirrors the neighbouring populate functions declared in
/// this header: `context` and `typeConverter` are used to construct the
/// patterns, and `patterns` is the set they are appended to.
/// NOTE(review): `enablePrimitiveOps` looks like it is kept only for signature
/// parity with the reduction variant - confirm against the definition.
void populateStablehloScatterToLinalgConversionPatterns(
MLIRContext *context, TypeConverter &typeConverter,
RewritePatternSet *patterns, bool enablePrimitiveOps);

/// Populates the patterns that convert scalar StableHLO ops to Arith ops.
void populateScalarHloToArithConversionPatterns(
MLIRContext *context, TypeConverter &typeConverter,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2699,6 +2699,8 @@ void populateStablehloToLinalgConversionPatterns(MLIRContext *context,
context, typeConverter, patterns);
detail::populateStablehloReductionToLinalgConversionPatterns(
context, typeConverter, patterns, enablePrimitiveOps);
detail::populateStablehloScatterToLinalgConversionPatterns(
context, typeConverter, patterns, enablePrimitiveOps);
detail::populateScalarHloToArithConversionPatterns(
context, typeConverter, patterns, isInBodyOfLinalgOps);
linalg::populateEraseUnusedOperandsAndResultsPatterns(*patterns);
Expand Down
204 changes: 204 additions & 0 deletions stablehlo/conversions/linalg/transforms/StablehloToLinalgScatter.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
/* Copyright 2019 The IREE Authors
Copyright 2023 OpenXLA Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implements logic for lowering StableHLO scatter ops to the Linalg dialect.
// These patterns are separated out into their own file to save on compilation
// time.

#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "stablehlo/conversions/linalg/transforms/Rewriters.h"
#include "stablehlo/dialect/StablehloOps.h"

namespace mlir::stablehlo {
namespace {
/// Returns true if the update computation of `op` is a plain assignment, i.e.
/// its body consists solely of a return that yields the update element:
///
///   ^bb0(%current: T, %update: T):
///     stablehlo.return %update : T
///
/// Only the single-input form (two block arguments, one returned value) is
/// recognized; anything else conservatively yields false.
bool isAssignment(stablehlo::ScatterOp op) {
  Region &region = op.getUpdateComputation();
  if (region.empty()) {
    return false;
  }

  Block &block = region.front();
  // The body must hold exactly one operation (the terminator) and take exactly
  // the (current, update) pair of arguments; this also rules out an empty
  // block before anything is dereferenced.
  if (!llvm::hasSingleElement(block) || block.getNumArguments() != 2) {
    return false;
  }

  auto returnOp = dyn_cast<stablehlo::ReturnOp>(&block.front());
  if (!returnOp || returnOp->getNumOperands() != 1) {
    return false;
  }

  // Argument 0 is the current operand element, argument 1 the update element.
  return returnOp->getOperand(0) == block.getArgument(1);
}

/// Returns true if the scatter writes its first update tensor as one whole
/// slice: every update dimension that is not listed in update_window_dims
/// must have extent 1. Updates without a ranked, fully static type cannot be
/// checked and therefore yield false.
bool singleFullSlices(stablehlo::ScatterOp op) {
  auto updateType =
      dyn_cast<RankedTensorType>(op.getUpdates().front().getType());
  if (!updateType || !updateType.hasStaticShape()) {
    return false;  // Nothing can be proven without a static shape.
  }

  auto windowDims = op.getScatterDimensionNumbers().getUpdateWindowDims();
  const int64_t updateRank = updateType.getRank();
  for (int64_t dim = 0; dim < updateRank; ++dim) {
    const bool isWindowDim = llvm::is_contained(windowDims, dim);
    // A non-window (scatter) dimension wider than 1 means several slices.
    if (!isWindowDim && updateType.getDimSize(dim) != 1) {
      return false;
    }
  }
  return true;
}

/// Returns true if `op` is equivalent to inserting its update tensor as a
/// single slice, one element thick, at a dynamically indexed position of its
/// input tensor, i.e. if it can be expressed as one tensor.insert_slice.
bool isInsertSliceScatter(stablehlo::ScatterOp op) {
  // Requirement 1: exactly one input, one update and one result tensor.
  if (op.getInputs().size() != 1 || op.getUpdates().size() != 1 ||
      op.getResults().size() != 1) {
    return false;
  }

  // Requirement 2: the update computation is an assignment (see isAssignment).
  if (!isAssignment(op)) {
    return false;
  }

  // Requirement 3: no batching, i.e.
  //   input_batching_dims = []
  //   scatter_indices_batching_dims = []
  auto scatterDimNumbers = op.getScatterDimensionNumbers();
  if (!scatterDimNumbers.getInputBatchingDims().empty() ||
      !scatterDimNumbers.getScatterIndicesBatchingDims().empty()) {
    return false;
  }

  // Requirement 4: the whole update tensor is inserted as one slice.
  if (!singleFullSlices(op)) {
    return false;
  }

  // Requirement 5: the scatter indices are a static tensor holding exactly one
  // element, so there is a single index to read.
  auto indicesType =
      dyn_cast<RankedTensorType>(op.getScatterIndices().getType());
  if (!indicesType || !indicesType.hasStaticShape() ||
      indicesType.getNumElements() != 1) {
    return false;
  }

  // Requirement 6: that single index addresses the one inserted window
  // dimension. Otherwise the index would either shift a window dimension or
  // leave some inserted dimension at offset 0, neither of which is the
  // "insert at position idx" form this lowering produces.
  auto indexedDims = scatterDimNumbers.getScatterDimsToOperandDims();
  auto insertedDims = scatterDimNumbers.getInsertedWindowDims();
  return indexedDims.size() == 1 && insertedDims.size() == 1 &&
         indexedDims.front() == insertedDims.front();
}

/// Pattern to lower insert_slice-like stablehlo::ScatterOps (see
/// isInsertSliceScatter) to a single tensor.insert_slice op.
///
/// The one scatter index is extracted from the indices tensor, cast to `index`
/// and used as the dynamic offset along the operand dimension named by
/// scatter_dims_to_operand_dims; all other offsets are 0 and all strides are 1.
/// Slice sizes are 1 for inserted window dimensions and the extent of the
/// corresponding update window dimension otherwise.
///
/// All structural checks run before any op is created, so a rejected scatter
/// is left untouched.
///
/// NOTE(review): the pattern is named "ReduceOp..." although it matches
/// ScatterOp; the name is kept because the populate function refers to it.
/// NOTE(review): scatter presumably tolerates out-of-bounds indices while
/// tensor.insert_slice expects the slice to fit - confirm that callers
/// guarantee in-bounds indices.
struct ReduceOpToInsertSliceConverter final
    : public OpConversionPattern<stablehlo::ScatterOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      stablehlo::ScatterOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    if (!isInsertSliceScatter(op)) {
      return rewriter.notifyMatchFailure(op,
                                         "not an insert_slice-like scatter");
    }

    // Go through the adaptor so the rewrite uses the operands as remapped by
    // the type converter rather than the original values.
    Value input = adaptor.getInputs().front();
    Value update = adaptor.getUpdates().front();
    Value scatterIndices = adaptor.getScatterIndices();

    auto inputTy = dyn_cast<RankedTensorType>(input.getType());
    auto updateTy = dyn_cast<RankedTensorType>(update.getType());
    auto indicesTy = dyn_cast<RankedTensorType>(scatterIndices.getType());
    if (!inputTy || !updateTy || !indicesTy || !updateTy.hasStaticShape()) {
      return rewriter.notifyMatchFailure(
          op, "expected ranked tensor operands and a static update shape");
    }

    auto scatterDimNumbers = op.getScatterDimensionNumbers();
    ArrayRef<int64_t> insertedWindowDims =
        scatterDimNumbers.getInsertedWindowDims();
    ArrayRef<int64_t> updateWindowDims =
        scatterDimNumbers.getUpdateWindowDims();
    ArrayRef<int64_t> scatterDimsToOperandDims =
        scatterDimNumbers.getScatterDimsToOperandDims();

    // A single index can only position the slice along a single operand
    // dimension.
    const int64_t inputRank = inputTy.getRank();
    if (scatterDimsToOperandDims.size() != 1 ||
        scatterDimsToOperandDims.front() < 0 ||
        scatterDimsToOperandDims.front() >= inputRank) {
      return rewriter.notifyMatchFailure(
          op, "expected exactly one valid indexed operand dimension");
    }
    const int64_t indexedDim = scatterDimsToOperandDims.front();

    // Derive offsets/sizes/strides per operand dimension. Window extents are
    // looked up through update_window_dims (the i-th entry belongs to the i-th
    // operand dimension that is not inserted) instead of assuming that the
    // update dimensions line up positionally with the operand dimensions.
    SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
    size_t nextWindowDim = 0;
    for (int64_t dim = 0; dim < inputRank; ++dim) {
      staticOffsets.push_back(dim == indexedDim ? ShapedType::kDynamic : 0);
      staticStrides.push_back(1);
      if (llvm::is_contained(insertedWindowDims, dim)) {
        staticSizes.push_back(1);
        continue;
      }
      if (nextWindowDim >= updateWindowDims.size()) {
        return rewriter.notifyMatchFailure(
            op, "fewer update window dims than operand window dims");
      }
      const int64_t updateDim = updateWindowDims[nextWindowDim++];
      if (updateDim < 0 || updateDim >= updateTy.getRank()) {
        return rewriter.notifyMatchFailure(op,
                                           "update window dim out of range");
      }
      staticSizes.push_back(updateTy.getDimSize(updateDim));
    }
    if (nextWindowDim != updateWindowDims.size()) {
      return rewriter.notifyMatchFailure(
          op, "more update window dims than operand window dims");
    }

    // tensor.insert_slice needs the source to be the slice itself or a
    // rank-reduced form of it. Accept the two layouts that satisfy this: the
    // update matches the slice sizes exactly, or the update consists of the
    // window dimensions only.
    const bool matchesSlice =
        ArrayRef<int64_t>(staticSizes) == updateTy.getShape();
    const bool windowDimsOnly =
        static_cast<size_t>(updateTy.getRank()) == updateWindowDims.size();
    if (!matchesSlice && !windowDimsOnly) {
      return rewriter.notifyMatchFailure(
          op, "update shape is not a (rank-reduced) form of the target slice");
    }

    // Read the single scatter index; it sits at position 0 along every
    // dimension of the one-element indices tensor.
    Location loc = op.getLoc();
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeroIndices(indicesTy.getRank(), zero);
    Value rawIndex =
        rewriter.create<tensor::ExtractOp>(loc, scatterIndices, zeroIndices)
            .getResult();

    Value offset = rawIndex;
    if (!rawIndex.getType().isIndex()) {
      // Signedness is taken from the original type, since the converted one
      // may have been made signless.
      const bool isUnsigned = cast<ShapedType>(op.getScatterIndices().getType())
                                  .getElementType()
                                  .isUnsignedInteger();
      if (isUnsigned) {
        offset = rewriter
                     .create<arith::IndexCastUIOp>(
                         loc, rewriter.getIndexType(), rawIndex)
                     .getResult();
      } else {
        offset = rewriter
                     .create<arith::IndexCastOp>(loc, rewriter.getIndexType(),
                                                 rawIndex)
                     .getResult();
      }
    }

    // Only the offset along `indexedDim` is dynamic; sizes and strides are
    // fully static.
    SmallVector<Value> dynOffsets{offset};
    SmallVector<Value> dynSizes, dynStrides;
    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
        op, update, input, dynOffsets, dynSizes, dynStrides, staticOffsets,
        staticSizes, staticStrides);
    return success();
  }
};
} // namespace

namespace detail {
void populateStablehloScatterToLinalgConversionPatterns(
    MLIRContext *context, TypeConverter &typeConverter,
    RewritePatternSet *patterns, bool enablePrimitiveOps) {
  // Specialized lowerings are registered with a raised benefit so that they
  // take priority over their generic versions. `enablePrimitiveOps` is not
  // read here.
  const PatternBenefit specializedBenefit(2);
  patterns->add<ReduceOpToInsertSliceConverter>(typeConverter, context,
                                                specializedBenefit);
}
} // namespace detail
} // namespace mlir::stablehlo