Skip to content

Commit 96acdcb

Browse files
authored
1 parent 3acda59 commit 96acdcb

20 files changed

+972
-82
lines changed

BUILD.bazel

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1107,6 +1107,24 @@ gentbl_cc_library(
11071107
deps = ["@llvm-project//mlir:PassBaseTdFiles"],
11081108
)
11091109

1110+
cc_library(
1111+
name = "stablehlo_broadcast_lowering",
1112+
srcs = [
1113+
"stablehlo/transforms/StablehloBroadcastLowering.cpp",
1114+
],
1115+
hdrs = [
1116+
"stablehlo/transforms/StablehloBroadcastLowering.h",
1117+
],
1118+
strip_include_prefix = ".",
1119+
deps = [
1120+
":stablehlo_ops",
1121+
"@llvm-project//llvm:Support",
1122+
"@llvm-project//mlir:IR",
1123+
"@llvm-project//mlir:ShapeDialect",
1124+
"@llvm-project//mlir:Support",
1125+
],
1126+
)
1127+
11101128
cc_library(
11111129
name = "stablehlo_pass_utils",
11121130
srcs = [

WORKSPACE.bazel

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ workspace(name = "stablehlo")
1717

1818
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
1919

20-
LLVM_COMMIT = "42a8ff877d47131ecb1280a1cc7e5e3c3bca6952"
20+
LLVM_COMMIT = "2bc22ea02edda5926f3e53f141def9bf212ac1db"
2121

22-
LLVM_SHA256 = "f768c5c3b987f68318b8ab3dd4530e54988dfe7d6bfb9b7c9c96acf503367d50"
22+
LLVM_SHA256 = "4a034eda852b3c2d448d38e8661cbac45ae2233a29defeb55913fa5205cd29f7"
2323

2424
http_archive(
2525
name = "llvm-raw",

build_tools/llvm_version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
42a8ff877d47131ecb1280a1cc7e5e3c3bca6952
1+
2bc22ea02edda5926f3e53f141def9bf212ac1db

stablehlo/dialect/StablehloOps.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3275,12 +3275,12 @@ Attribute StablehloDialect::parseAttribute(DialectAsmParser& parser,
32753275
// Entry point for Attribute printing, TableGen generated code will handle the
32763276
// dispatch to the individual classes.
32773277
void StablehloDialect::printAttribute(Attribute attr,
3278-
DialectAsmPrinter& os) const {
3278+
DialectAsmPrinter& printer) const {
32793279
if (auto type_extensions = dyn_cast<TypeExtensionsAttr>(attr)) {
3280-
hlo::printTypeExtensions(cast<hlo::BoundedAttrInterface>(attr), os);
3280+
hlo::printTypeExtensions(cast<hlo::BoundedAttrInterface>(attr), printer);
32813281
return;
32823282
}
3283-
LogicalResult result = generatedAttributePrinter(attr, os);
3283+
LogicalResult result = generatedAttributePrinter(attr, printer);
32843284
(void)result;
32853285
assert(succeeded(result));
32863286
}

stablehlo/dialect/StablehloOps.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,13 +93,14 @@ class StablehloDialect : public Dialect {
9393
Type parseType(DialectAsmParser& parser) const override;
9494

9595
// Prints a type registered to this dialect.
96-
void printType(Type type, DialectAsmPrinter& os) const override;
96+
void printType(Type type, DialectAsmPrinter& printer) const override;
9797

9898
// Parses an attribute registered to this dialect.
9999
Attribute parseAttribute(DialectAsmParser& parser, Type type) const override;
100100

101101
// Prints an attribute registered to this dialect.
102-
void printAttribute(Attribute attr, DialectAsmPrinter& os) const override;
102+
void printAttribute(Attribute attr,
103+
DialectAsmPrinter& printer) const override;
103104

104105
// Get the set dialect version.
105106
std::optional<StablehloDialectVersion> getVersion() const;

stablehlo/integrations/python/mlir/dialects/InterpreterOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,6 @@ limitations under the License.
1717
#ifndef STABLEHLO_INTEGRATIONS_PYTHON_INTERPRETER_OPS
1818
#define STABLEHLO_INTEGRATIONS_PYTHON_INTERPRETER_OPS
1919

20-
include "third_party/stablehlo/stablehlo/reference/InterpreterOps.h"
20+
include "stablehlo/reference/InterpreterOps.h"
2121

2222
#endif

stablehlo/tests/BUILD.bazel

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,8 @@ cc_library(
102102
deps = [
103103
":test_utils_inc_gen",
104104
"//:stablehlo_assembly_format",
105+
"//:stablehlo_broadcast_lowering",
106+
"//:stablehlo_ops",
105107
"@llvm-project//llvm:Support",
106108
"@llvm-project//mlir:FuncDialect",
107109
"@llvm-project//mlir:IR",

stablehlo/tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ add_mlir_library(StablehloTestUtils
4848
MLIRSupport
4949
MLIRTransformUtils
5050
StablehloAssemblyFormat
51+
StablehloBroadcastLowering
5152
)
5253

5354
set(LLVM_TARGET_DEFINITIONS CheckOps.td)

stablehlo/tests/TestUtils.cpp

Lines changed: 78 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ limitations under the License.
1919
#include <utility>
2020

2121
#include "llvm/ADT/STLExtras.h"
22+
#include "llvm/ADT/SmallVector.h"
2223
#include "llvm/Support/Casting.h"
2324
#include "mlir/Dialect/Func/IR/FuncOps.h"
2425
#include "mlir/Dialect/Shape/IR/Shape.h"
@@ -28,26 +29,50 @@ limitations under the License.
2829
#include "mlir/IR/Operation.h"
2930
#include "mlir/IR/OperationSupport.h"
3031
#include "mlir/IR/PatternMatch.h"
32+
#include "mlir/IR/TypeRange.h"
3133
#include "mlir/Interfaces/InferTypeOpInterface.h"
3234
#include "mlir/Interfaces/SideEffectInterfaces.h"
3335
#include "mlir/Pass/Pass.h"
3436
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
3537
#include "mlir/Support/LLVM.h"
3638
#include "mlir/Support/LogicalResult.h"
3739
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
40+
#include "stablehlo/dialect/StablehloOps.h"
41+
#include "stablehlo/transforms/StablehloBroadcastLowering.h"
3842

3943
namespace mlir {
4044
namespace hlo {
4145

4246
namespace {
4347

48+
struct BroadcastValuesPattern : public RewritePattern {
49+
explicit BroadcastValuesPattern(MLIRContext* context)
50+
: RewritePattern("hlo_test_broadcast.numpy_broadcast", 1, context) {}
51+
LogicalResult matchAndRewrite(Operation* op,
52+
PatternRewriter& rewriter) const override {
53+
// Process all operands
54+
SmallVector<Value> operands = llvm::to_vector(op->getOperands());
55+
auto broadcastedOperands =
56+
stablehlo::numpyBroadcastIfNeeded(rewriter, operands);
57+
if (failed(broadcastedOperands)) return failure();
58+
59+
// Replace with custom call to avoid pattern reapplication
60+
auto customCall = stablehlo::CustomCallOp::create(
61+
rewriter, op->getLoc(), op->getResultTypes(), *broadcastedOperands);
62+
customCall.setCallTargetName("numpy_broadcasted");
63+
customCall.setHasSideEffect(true);
64+
rewriter.replaceOp(op, customCall);
65+
return success();
66+
}
67+
};
68+
4469
struct InferReturnTypesPattern : public RewritePattern {
45-
explicit InferReturnTypesPattern(MLIRContext *context)
70+
explicit InferReturnTypesPattern(MLIRContext* context)
4671
: RewritePattern("hlo_test_infer.get_return_types", 1, context) {}
47-
LogicalResult matchAndRewrite(Operation *op,
48-
PatternRewriter &rewriter) const override {
72+
LogicalResult matchAndRewrite(Operation* op,
73+
PatternRewriter& rewriter) const override {
4974
if (op->getNumOperands() != 1) return failure();
50-
auto *definingOp = op->getOperand(0).getDefiningOp();
75+
auto* definingOp = op->getOperand(0).getDefiningOp();
5176
auto definingOpInt =
5277
llvm::dyn_cast_or_null<InferTypeOpInterface>(definingOp);
5378
if (!definingOpInt) return failure();
@@ -62,8 +87,8 @@ struct InferReturnTypesPattern : public RewritePattern {
6287
OperationState state(op->getLoc(), "hlo_test_infer.return_types",
6388
op->getOperands(), op->getResultTypes(),
6489
op->getAttrs());
65-
auto *newOp = rewriter.create(state);
66-
for (const auto &it : llvm::enumerate(types))
90+
auto* newOp = rewriter.create(state);
91+
for (const auto& it : llvm::enumerate(types))
6792
newOp->setAttr((StringRef("types") + Twine(it.index())).str(),
6893
TypeAttr::get(it.value()));
6994
rewriter.replaceOp(op, {newOp->getResults()});
@@ -72,10 +97,10 @@ struct InferReturnTypesPattern : public RewritePattern {
7297
};
7398

7499
struct ReifyReturnTypeShapesPattern : public RewritePattern {
75-
explicit ReifyReturnTypeShapesPattern(MLIRContext *context)
100+
explicit ReifyReturnTypeShapesPattern(MLIRContext* context)
76101
: RewritePattern("hlo_test_infer.reify_return_type_shapes", 1, context) {}
77-
LogicalResult matchAndRewrite(Operation *op,
78-
PatternRewriter &rewriter) const override {
102+
LogicalResult matchAndRewrite(Operation* op,
103+
PatternRewriter& rewriter) const override {
79104
if (op->getNumOperands() != 1) return failure();
80105
auto definingOp =
81106
op->getOperand(0).getDefiningOp<InferShapedTypeOpInterface>();
@@ -89,7 +114,7 @@ struct ReifyReturnTypeShapesPattern : public RewritePattern {
89114
}
90115
};
91116

92-
LogicalResult checkSpeculatability(PatternRewriter &rewriter, Operation *op,
117+
LogicalResult checkSpeculatability(PatternRewriter& rewriter, Operation* op,
93118
mlir::Speculation::Speculatability spec) {
94119
if (op->getNumOperands() != 1) return failure();
95120
auto definingOp =
@@ -106,67 +131,86 @@ LogicalResult checkSpeculatability(PatternRewriter &rewriter, Operation *op,
106131
}
107132

108133
struct IsSpeculatablePattern : public RewritePattern {
109-
explicit IsSpeculatablePattern(MLIRContext *context)
134+
explicit IsSpeculatablePattern(MLIRContext* context)
110135
: RewritePattern("hlo_test_speculatability.is_speculatable", 1, context) {
111136
}
112-
LogicalResult matchAndRewrite(Operation *op,
113-
PatternRewriter &rewriter) const override {
137+
LogicalResult matchAndRewrite(Operation* op,
138+
PatternRewriter& rewriter) const override {
114139
return checkSpeculatability(rewriter, op, mlir::Speculation::Speculatable);
115140
}
116141
};
117142

118143
struct IsRecursivelySpeculatablePattern : public RewritePattern {
119-
explicit IsRecursivelySpeculatablePattern(MLIRContext *context)
144+
explicit IsRecursivelySpeculatablePattern(MLIRContext* context)
120145
: RewritePattern("hlo_test_speculatability.is_recursively_speculatable",
121146
1, context) {}
122-
LogicalResult matchAndRewrite(Operation *op,
123-
PatternRewriter &rewriter) const override {
147+
LogicalResult matchAndRewrite(Operation* op,
148+
PatternRewriter& rewriter) const override {
124149
return checkSpeculatability(rewriter, op,
125150
mlir::Speculation::RecursivelySpeculatable);
126151
}
127152
};
128153

129154
struct IsNotSpeculatablePattern : public RewritePattern {
130-
explicit IsNotSpeculatablePattern(MLIRContext *context)
155+
explicit IsNotSpeculatablePattern(MLIRContext* context)
131156
: RewritePattern("hlo_test_speculatability.is_not_speculatable", 1,
132157
context) {}
133-
LogicalResult matchAndRewrite(Operation *op,
134-
PatternRewriter &rewriter) const override {
158+
LogicalResult matchAndRewrite(Operation* op,
159+
PatternRewriter& rewriter) const override {
135160
return checkSpeculatability(rewriter, op,
136161
mlir::Speculation::NotSpeculatable);
137162
}
138163
};
139164

165+
#define GEN_PASS_DEF_HLOTESTBROADCASTPASS
140166
#define GEN_PASS_DEF_HLOTESTINFERPASS
141167
#define GEN_PASS_DEF_HLOTESTSPECULATABILITYPASS
142168
#include "stablehlo/tests/TestUtils.h.inc"
143169

170+
struct HloTestBroadcastPass
171+
: public impl::HloTestBroadcastPassBase<HloTestBroadcastPass> {
172+
LogicalResult initialize(MLIRContext* context) override {
173+
RewritePatternSet patterns(context);
174+
patterns.add<BroadcastValuesPattern>(context);
175+
patterns_ = std::move(patterns);
176+
return success();
177+
}
178+
179+
void runOnOperation() override {
180+
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns_))))
181+
return signalPassFailure();
182+
}
183+
184+
private:
185+
FrozenRewritePatternSet patterns_;
186+
};
187+
144188
struct HloTestInferPass : public impl::HloTestInferPassBase<HloTestInferPass> {
145-
LogicalResult initialize(MLIRContext *context) override {
146-
RewritePatternSet patterns_(context);
147-
patterns_.add<InferReturnTypesPattern>(context);
148-
patterns_.add<ReifyReturnTypeShapesPattern>(context);
149-
patterns = std::move(patterns_);
189+
LogicalResult initialize(MLIRContext* context) override {
190+
RewritePatternSet patterns(context);
191+
patterns.add<InferReturnTypesPattern>(context);
192+
patterns.add<ReifyReturnTypeShapesPattern>(context);
193+
patterns_ = std::move(patterns);
150194
return success();
151195
}
152196

153197
void runOnOperation() override {
154-
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
198+
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns_))))
155199
return signalPassFailure();
156200
}
157201

158202
private:
159-
FrozenRewritePatternSet patterns;
203+
FrozenRewritePatternSet patterns_;
160204
};
161205

162206
struct HloTestSpeculatabilityPass
163207
: public impl::HloTestSpeculatabilityPassBase<HloTestSpeculatabilityPass> {
164-
LogicalResult initialize(MLIRContext *context) override {
165-
RewritePatternSet patterns_(context);
166-
patterns_.add<IsSpeculatablePattern>(context);
167-
patterns_.add<IsNotSpeculatablePattern>(context);
168-
patterns_.add<IsRecursivelySpeculatablePattern>(context);
169-
patterns = std::move(patterns_);
208+
LogicalResult initialize(MLIRContext* context) override {
209+
RewritePatternSet patterns(context);
210+
patterns.add<IsSpeculatablePattern>(context);
211+
patterns.add<IsNotSpeculatablePattern>(context);
212+
patterns.add<IsRecursivelySpeculatablePattern>(context);
213+
patterns_ = std::move(patterns);
170214
return success();
171215
}
172216

@@ -175,11 +219,11 @@ struct HloTestSpeculatabilityPass
175219
config.setMaxIterations(1)
176220
.setUseTopDownTraversal(true)
177221
.setRegionSimplificationLevel(GreedySimplifyRegionLevel::Disabled);
178-
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
222+
(void)applyPatternsGreedily(getOperation(), std::move(patterns_));
179223
}
180224

181225
private:
182-
FrozenRewritePatternSet patterns;
226+
FrozenRewritePatternSet patterns_;
183227
};
184228

185229
#define GEN_PASS_REGISTRATION

stablehlo/tests/TestUtils.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,11 @@ limitations under the License.
1616

1717
include "mlir/Pass/PassBase.td"
1818

19+
def HloTestBroadcastPass : Pass<"hlo-test-broadcast", "func::FuncOp"> {
20+
let summary = "Uses test ops to invoke BroadcastUtils methods.";
21+
let dependentDialects = ["stablehlo::StablehloDialect"];
22+
}
23+
1924
def HloTestInferPass : Pass<"hlo-test-infer", "func::FuncOp"> {
2025
let summary = "Uses test ops to invoke InferShapedTypeOpInterface methods.";
2126
let dependentDialects = ["shape::ShapeDialect"];

0 commit comments

Comments
 (0)