@@ -19,6 +19,7 @@ limitations under the License.
1919#include < utility>
2020
2121#include " llvm/ADT/STLExtras.h"
22+ #include " llvm/ADT/SmallVector.h"
2223#include " llvm/Support/Casting.h"
2324#include " mlir/Dialect/Func/IR/FuncOps.h"
2425#include " mlir/Dialect/Shape/IR/Shape.h"
@@ -28,26 +29,50 @@ limitations under the License.
2829#include " mlir/IR/Operation.h"
2930#include " mlir/IR/OperationSupport.h"
3031#include " mlir/IR/PatternMatch.h"
32+ #include " mlir/IR/TypeRange.h"
3133#include " mlir/Interfaces/InferTypeOpInterface.h"
3234#include " mlir/Interfaces/SideEffectInterfaces.h"
3335#include " mlir/Pass/Pass.h"
3436#include " mlir/Rewrite/FrozenRewritePatternSet.h"
3537#include " mlir/Support/LLVM.h"
3638#include " mlir/Support/LogicalResult.h"
3739#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
40+ #include " stablehlo/dialect/StablehloOps.h"
41+ #include " stablehlo/transforms/StablehloBroadcastLowering.h"
3842
3943namespace mlir {
4044namespace hlo {
4145
4246namespace {
4347
48+ struct BroadcastValuesPattern : public RewritePattern {
49+ explicit BroadcastValuesPattern (MLIRContext* context)
50+ : RewritePattern(" hlo_test_broadcast.numpy_broadcast" , 1 , context) {}
51+ LogicalResult matchAndRewrite (Operation* op,
52+ PatternRewriter& rewriter) const override {
53+ // Process all operands
54+ SmallVector<Value> operands = llvm::to_vector (op->getOperands ());
55+ auto broadcastedOperands =
56+ stablehlo::numpyBroadcastIfNeeded (rewriter, operands);
57+ if (failed (broadcastedOperands)) return failure ();
58+
59+ // Replace with custom call to avoid pattern reapplication
60+ auto customCall = stablehlo::CustomCallOp::create (
61+ rewriter, op->getLoc (), op->getResultTypes (), *broadcastedOperands);
62+ customCall.setCallTargetName (" numpy_broadcasted" );
63+ customCall.setHasSideEffect (true );
64+ rewriter.replaceOp (op, customCall);
65+ return success ();
66+ }
67+ };
68+
4469struct InferReturnTypesPattern : public RewritePattern {
45- explicit InferReturnTypesPattern (MLIRContext * context)
70+ explicit InferReturnTypesPattern (MLIRContext* context)
4671 : RewritePattern(" hlo_test_infer.get_return_types" , 1 , context) {}
47- LogicalResult matchAndRewrite (Operation * op,
48- PatternRewriter & rewriter) const override {
72+ LogicalResult matchAndRewrite (Operation* op,
73+ PatternRewriter& rewriter) const override {
4974 if (op->getNumOperands () != 1 ) return failure ();
50- auto * definingOp = op->getOperand (0 ).getDefiningOp ();
75+ auto * definingOp = op->getOperand (0 ).getDefiningOp ();
5176 auto definingOpInt =
5277 llvm::dyn_cast_or_null<InferTypeOpInterface>(definingOp);
5378 if (!definingOpInt) return failure ();
@@ -62,8 +87,8 @@ struct InferReturnTypesPattern : public RewritePattern {
6287 OperationState state (op->getLoc (), " hlo_test_infer.return_types" ,
6388 op->getOperands (), op->getResultTypes (),
6489 op->getAttrs ());
65- auto * newOp = rewriter.create (state);
66- for (const auto & it : llvm::enumerate (types))
90+ auto * newOp = rewriter.create (state);
91+ for (const auto & it : llvm::enumerate (types))
6792 newOp->setAttr ((StringRef (" types" ) + Twine (it.index ())).str (),
6893 TypeAttr::get (it.value ()));
6994 rewriter.replaceOp (op, {newOp->getResults ()});
@@ -72,10 +97,10 @@ struct InferReturnTypesPattern : public RewritePattern {
7297};
7398
7499struct ReifyReturnTypeShapesPattern : public RewritePattern {
75- explicit ReifyReturnTypeShapesPattern (MLIRContext * context)
100+ explicit ReifyReturnTypeShapesPattern (MLIRContext* context)
76101 : RewritePattern(" hlo_test_infer.reify_return_type_shapes" , 1 , context) {}
77- LogicalResult matchAndRewrite (Operation * op,
78- PatternRewriter & rewriter) const override {
102+ LogicalResult matchAndRewrite (Operation* op,
103+ PatternRewriter& rewriter) const override {
79104 if (op->getNumOperands () != 1 ) return failure ();
80105 auto definingOp =
81106 op->getOperand (0 ).getDefiningOp <InferShapedTypeOpInterface>();
@@ -89,7 +114,7 @@ struct ReifyReturnTypeShapesPattern : public RewritePattern {
89114 }
90115};
91116
92- LogicalResult checkSpeculatability (PatternRewriter & rewriter, Operation * op,
117+ LogicalResult checkSpeculatability (PatternRewriter& rewriter, Operation* op,
93118 mlir::Speculation::Speculatability spec) {
94119 if (op->getNumOperands () != 1 ) return failure ();
95120 auto definingOp =
@@ -106,67 +131,86 @@ LogicalResult checkSpeculatability(PatternRewriter &rewriter, Operation *op,
106131}
107132
108133struct IsSpeculatablePattern : public RewritePattern {
109- explicit IsSpeculatablePattern (MLIRContext * context)
134+ explicit IsSpeculatablePattern (MLIRContext* context)
110135 : RewritePattern(" hlo_test_speculatability.is_speculatable" , 1 , context) {
111136 }
112- LogicalResult matchAndRewrite (Operation * op,
113- PatternRewriter & rewriter) const override {
137+ LogicalResult matchAndRewrite (Operation* op,
138+ PatternRewriter& rewriter) const override {
114139 return checkSpeculatability (rewriter, op, mlir::Speculation::Speculatable);
115140 }
116141};
117142
118143struct IsRecursivelySpeculatablePattern : public RewritePattern {
119- explicit IsRecursivelySpeculatablePattern (MLIRContext * context)
144+ explicit IsRecursivelySpeculatablePattern (MLIRContext* context)
120145 : RewritePattern(" hlo_test_speculatability.is_recursively_speculatable" ,
121146 1 , context) {}
122- LogicalResult matchAndRewrite (Operation * op,
123- PatternRewriter & rewriter) const override {
147+ LogicalResult matchAndRewrite (Operation* op,
148+ PatternRewriter& rewriter) const override {
124149 return checkSpeculatability (rewriter, op,
125150 mlir::Speculation::RecursivelySpeculatable);
126151 }
127152};
128153
129154struct IsNotSpeculatablePattern : public RewritePattern {
130- explicit IsNotSpeculatablePattern (MLIRContext * context)
155+ explicit IsNotSpeculatablePattern (MLIRContext* context)
131156 : RewritePattern(" hlo_test_speculatability.is_not_speculatable" , 1 ,
132157 context) {}
133- LogicalResult matchAndRewrite (Operation * op,
134- PatternRewriter & rewriter) const override {
158+ LogicalResult matchAndRewrite (Operation* op,
159+ PatternRewriter& rewriter) const override {
135160 return checkSpeculatability (rewriter, op,
136161 mlir::Speculation::NotSpeculatable);
137162 }
138163};
139164
165+ #define GEN_PASS_DEF_HLOTESTBROADCASTPASS
140166#define GEN_PASS_DEF_HLOTESTINFERPASS
141167#define GEN_PASS_DEF_HLOTESTSPECULATABILITYPASS
142168#include " stablehlo/tests/TestUtils.h.inc"
143169
170+ struct HloTestBroadcastPass
171+ : public impl::HloTestBroadcastPassBase<HloTestBroadcastPass> {
172+ LogicalResult initialize (MLIRContext* context) override {
173+ RewritePatternSet patterns (context);
174+ patterns.add <BroadcastValuesPattern>(context);
175+ patterns_ = std::move (patterns);
176+ return success ();
177+ }
178+
179+ void runOnOperation () override {
180+ if (failed (applyPatternsGreedily (getOperation (), std::move (patterns_))))
181+ return signalPassFailure ();
182+ }
183+
184+ private:
185+ FrozenRewritePatternSet patterns_;
186+ };
187+
144188struct HloTestInferPass : public impl ::HloTestInferPassBase<HloTestInferPass> {
145- LogicalResult initialize (MLIRContext * context) override {
146- RewritePatternSet patterns_ (context);
147- patterns_ .add <InferReturnTypesPattern>(context);
148- patterns_ .add <ReifyReturnTypeShapesPattern>(context);
149- patterns = std::move (patterns_ );
189+ LogicalResult initialize (MLIRContext* context) override {
190+ RewritePatternSet patterns (context);
191+ patterns .add <InferReturnTypesPattern>(context);
192+ patterns .add <ReifyReturnTypeShapesPattern>(context);
193+ patterns_ = std::move (patterns );
150194 return success ();
151195 }
152196
153197 void runOnOperation () override {
154- if (failed (applyPatternsGreedily (getOperation (), std::move (patterns ))))
198+ if (failed (applyPatternsGreedily (getOperation (), std::move (patterns_ ))))
155199 return signalPassFailure ();
156200 }
157201
158202 private:
159- FrozenRewritePatternSet patterns ;
203+ FrozenRewritePatternSet patterns_ ;
160204};
161205
162206struct HloTestSpeculatabilityPass
163207 : public impl::HloTestSpeculatabilityPassBase<HloTestSpeculatabilityPass> {
164- LogicalResult initialize (MLIRContext * context) override {
165- RewritePatternSet patterns_ (context);
166- patterns_ .add <IsSpeculatablePattern>(context);
167- patterns_ .add <IsNotSpeculatablePattern>(context);
168- patterns_ .add <IsRecursivelySpeculatablePattern>(context);
169- patterns = std::move (patterns_ );
208+ LogicalResult initialize (MLIRContext* context) override {
209+ RewritePatternSet patterns (context);
210+ patterns .add <IsSpeculatablePattern>(context);
211+ patterns .add <IsNotSpeculatablePattern>(context);
212+ patterns .add <IsRecursivelySpeculatablePattern>(context);
213+ patterns_ = std::move (patterns );
170214 return success ();
171215 }
172216
@@ -175,11 +219,11 @@ struct HloTestSpeculatabilityPass
175219 config.setMaxIterations (1 )
176220 .setUseTopDownTraversal (true )
177221 .setRegionSimplificationLevel (GreedySimplifyRegionLevel::Disabled);
178- (void )applyPatternsGreedily (getOperation (), std::move (patterns ));
222+ (void )applyPatternsGreedily (getOperation (), std::move (patterns_ ));
179223 }
180224
181225 private:
182- FrozenRewritePatternSet patterns ;
226+ FrozenRewritePatternSet patterns_ ;
183227};
184228
185229#define GEN_PASS_REGISTRATION
0 commit comments