Skip to content

Commit 85fe3af

Browse files
authored
Added stablehlo divide fp to tosa reciprocal+mul conversion (#2873)
According to [TOSA documentation](https://mlir.llvm.org/docs/Dialects/TOSA/#tosaintdiv-mlirtosaintdivop) for FP division "_Floating point divide should use RECIPROCAL and MUL_": This PR implements conversion form stablehlo.divide into combination of TOSA ops.
1 parent 96acdcb commit 85fe3af

File tree

2 files changed

+46
-0
lines changed

2 files changed

+46
-0
lines changed

stablehlo/conversions/tosa/tests/binary.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,15 @@ func.func @divide(%arg0 : tensor<10xi32>, %arg1 : tensor<10xi32>) -> tensor<10xi
5050
return %0 : tensor<10xi32>
5151
}
5252

53+
// CHECK-LABEL: @divide_f32
54+
func.func @divide_f32(%arg0 : tensor<10xf32>, %arg1 : tensor<10xf32>) -> tensor<10xf32> {
55+
// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}
56+
// CHECK-DAG: %[[VAR1:.*]] = tosa.reciprocal %arg1
57+
// CHECK: tosa.mul %arg0, %[[VAR1]], %[[VAR0]]
58+
%0 = "stablehlo.divide"(%arg0, %arg1) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
59+
return %0 : tensor<10xf32>
60+
}
61+
5362
// CHECK-LABEL: @dot_vector_vector
5463
func.func @dot_vector_vector(%arg0 : tensor<3xf32>, %arg1 : tensor<3xf32>) -> tensor<f32> {
5564
// CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {values = dense<> : tensor<0xindex>} : () -> !tosa.shape<0>

stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.cpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -523,6 +523,41 @@ struct ConvertStablehloReshapeOp
523523
}
524524
};
525525

526+
struct ConvertStablehloFloatDivideOp
527+
: public OpRewritePattern<stablehlo::DivOp> {
528+
using OpRewritePattern<stablehlo::DivOp>::OpRewritePattern;
529+
530+
LogicalResult matchAndRewrite(stablehlo::DivOp op,
531+
PatternRewriter& rewriter) const override {
532+
auto lhsType = dyn_cast<RankedTensorType>(op.getLhs().getType());
533+
auto rhsType = dyn_cast<RankedTensorType>(op.getRhs().getType());
534+
if (!lhsType || !rhsType) {
535+
return rewriter.notifyMatchFailure(op, "expected ranked tensor types");
536+
}
537+
538+
if (!llvm::isa<mlir::FloatType>(lhsType.getElementType()) &&
539+
!llvm::isa<mlir::FloatType>(rhsType.getElementType())) {
540+
return rewriter.notifyMatchFailure(
541+
op, "only converts floating point division");
542+
}
543+
544+
auto shiftTensorType = RankedTensorType::get({1}, rewriter.getI8Type());
545+
auto zeroShiftValue = DenseElementsAttr::get(
546+
shiftTensorType, rewriter.getIntegerAttr(rewriter.getI8Type(), 0));
547+
auto shiftConst = tosa::ConstOp::create(rewriter, op.getLoc(),
548+
shiftTensorType, zeroShiftValue);
549+
550+
auto reciprocalOp =
551+
tosa::ReciprocalOp::create(rewriter, op.getLoc(), rhsType, op.getRhs());
552+
553+
auto mulOp = tosa::MulOp::create(rewriter, op.getLoc(), op.getType(),
554+
op.getLhs(), reciprocalOp, shiftConst);
555+
556+
rewriter.replaceOp(op, mulOp.getResult());
557+
return success();
558+
}
559+
};
560+
526561
LogicalResult StablehloLegalizeToTosaPass::initialize(MLIRContext* ctx) {
527562
RewritePatternSet patternList(ctx);
528563
populateGeneratedPDLLPatterns(patternList);
@@ -543,6 +578,8 @@ LogicalResult StablehloLegalizeToTosaPass::initialize(MLIRContext* ctx) {
543578
patternList.addWithLabel<ConvertStablehloWhileOp>({"StablehloWhile"}, ctx);
544579
patternList.addWithLabel<ConvertStablehloReshapeOp>({"StablehloReshape"},
545580
ctx);
581+
patternList.addWithLabel<ConvertStablehloFloatDivideOp>(
582+
{"StablehloFloatDivide"}, ctx);
546583

547584
patterns = std::move(patternList);
548585
return success();

0 commit comments

Comments
 (0)