@@ -272,7 +272,7 @@ static Operation *findParentFillOp(Value val) {
272
272
!isa<linalg::FillOp>(currentOp)) {
273
273
currentOp = currentOp->getResult (0 ).getDefiningOp ();
274
274
}
275
- if (isa<linalg::FillOp>(currentOp)) {
275
+ if (currentOp && isa<linalg::FillOp>(currentOp)) {
276
276
return currentOp;
277
277
}
278
278
@@ -322,11 +322,10 @@ static unsigned getOprandDim(linalg::LinalgOp &linalgOp, unsigned iteratorPos,
322
322
return linalgOp.getShape (linalgOp.getDpsInputOperand (operandIdx))[dimPos];
323
323
}
324
324
325
- static LogicalResult setStaticSizeForExtractSliceOp (RewriterBase &rewriter,
326
- Operation *op,
327
- bool isExtract,
328
- SmallVector<int64_t > size,
329
- int shrinDimNum = 0 ) {
325
+ static void setStaticSizeForExtractSliceOp (RewriterBase &rewriter,
326
+ Operation *op, bool isExtract,
327
+ SmallVector<int64_t > size,
328
+ int shrinDimNum = 0 ) {
330
329
OpBuilder::InsertionGuard guard (rewriter);
331
330
rewriter.setInsertionPoint (op);
332
331
if (auto extractSlice = dyn_cast<tensor::ExtractSliceOp>(op)) {
@@ -348,15 +347,12 @@ static LogicalResult setStaticSizeForExtractSliceOp(RewriterBase &rewriter,
348
347
extractSlice, extractSlice.getSource (), mixedOffsets, mixedSizes,
349
348
mixedStrides);
350
349
}
351
- } else {
352
- return failure ();
353
350
}
354
- return mlir::success ();
355
351
}
356
352
357
- static LogicalResult setStaticSizeForInsertSliceOp (RewriterBase &rewriter,
358
- Operation *op, Value source,
359
- SmallVector<int64_t > size) {
353
+ static void setStaticSizeForInsertSliceOp (RewriterBase &rewriter, Operation *op ,
354
+ Value source,
355
+ SmallVector<int64_t > size) {
360
356
OpBuilder::InsertionGuard guard (rewriter);
361
357
rewriter.setInsertionPoint (op);
362
358
if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(op)) {
@@ -369,10 +365,7 @@ static LogicalResult setStaticSizeForInsertSliceOp(RewriterBase &rewriter,
369
365
rewriter.replaceOpWithNewOp <tensor::InsertSliceOp>(
370
366
insertSlice, source, insertSlice.getDest (), mixedOffsets, mixedSizes,
371
367
mixedStrides);
372
- } else {
373
- return failure ();
374
368
}
375
- return success ();
376
369
}
377
370
378
371
using InnermostFullResultCallBackFn = std::function<FailureOr<linalg::LinalgOp>(
@@ -691,7 +684,6 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
691
684
linalg::LinalgOp originOp,
692
685
linalg::LinalgOp currentOp,
693
686
innerBodyGenerationOption &option) const {
694
-
695
687
mlir::easybuild::EasyBuilder eb{rewriter, originOp.getLoc ()};
696
688
auto operandDimTypes = getOprandDimType (originOp);
697
689
auto cfg = MatmulConfigAnalysis (originOp.getOperation ()).getConfig ();
@@ -744,6 +736,7 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
744
736
CInnermostDims =
745
737
SmallVector<int64_t >{cfg.innerMostMBlock , cfg.innerMostNBlock };
746
738
}
739
+
747
740
if (NDimNum > 1 ) {
748
741
firstN = true ;
749
742
firstK = true ;
@@ -780,21 +773,17 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
780
773
781
774
// update the extractSlice to static size, replace it with
782
775
// useBlockedLayout when
783
- if (failed (setStaticSizeForExtractSliceOp (
784
- rewriter, currentOp.getDpsInits ()[0 ].getDefiningOp (), true ,
785
- CInnermostDims, MDimNum > 1 ? 2 : 0 )) ||
786
- failed (setStaticSizeForExtractSliceOp (
787
- rewriter, currentOp.getDpsInputs ()[1 ].getDefiningOp (), true ,
788
- BInnermostDims, NDimNum > 1 )) ||
789
- failed (setStaticSizeForExtractSliceOp (
790
- rewriter, currentOp.getDpsInputs ()[0 ].getDefiningOp (), true ,
791
- AInnermostDims, MDimNum > 1 )) ||
792
- (currentOp.getDpsInits ().size () > 1 &&
793
- failed (setStaticSizeForExtractSliceOp (
794
- rewriter, currentOp.getDpsInits ()[1 ].getDefiningOp (), true ,
795
- CInnermostDims, MDimNum > 1 ? 2 : 0 )))) {
796
- return failure ();
776
+ setStaticSizeForExtractSliceOp (rewriter,
777
+ currentOp.getDpsInputs ()[1 ].getDefiningOp (),
778
+ true , BInnermostDims, NDimNum > 1 );
779
+ setStaticSizeForExtractSliceOp (rewriter,
780
+ currentOp.getDpsInputs ()[0 ].getDefiningOp (),
781
+ true , AInnermostDims, MDimNum > 1 );
782
+ for (auto init : currentOp.getDpsInits ()) {
783
+ setStaticSizeForExtractSliceOp (rewriter, init.getDefiningOp (), true ,
784
+ CInnermostDims, MDimNum > 1 ? 2 : 0 );
797
785
}
786
+
798
787
// View the tensor to brgemm required format
799
788
Value dataOprand = tensorViewRankedTensor (
800
789
rewriter,
@@ -841,10 +830,7 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
841
830
842
831
// Insert the result back to the original tensor
843
832
for (Operation *user : currentOp->getResult (0 ).getUsers ()) {
844
- if (failed (setStaticSizeForInsertSliceOp (rewriter, user, result,
845
- CInnermostDims))) {
846
- return failure ();
847
- }
833
+ setStaticSizeForInsertSliceOp (rewriter, user, result, CInnermostDims);
848
834
}
849
835
850
836
if (option.needLowPrecisionCast ) {
@@ -869,10 +855,8 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
869
855
auto ifOp = eb.getLastOperaion ();
870
856
// set static size for the insertSliceOp of copyOp
871
857
for (Operation *user : currentOp->getResult (1 ).getUsers ()) {
872
- if (failed (setStaticSizeForInsertSliceOp (
873
- rewriter, user, ifOp->getResult (0 ), CInnermostDims))) {
874
- return failure ();
875
- }
858
+ setStaticSizeForInsertSliceOp (rewriter, user, ifOp->getResult (0 ),
859
+ CInnermostDims);
876
860
}
877
861
rewriter.replaceOp (currentOp, {matmul->getResult (0 ), ifOp->getResult (0 )});
878
862
} else {
@@ -885,7 +869,11 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
885
869
if (cfg.KThreads <= 1 ) {
886
870
// if use k slicing, the fill op is still need to be kept for the reduce
887
871
// init
888
- rewriter.replaceOp (fillOp, fillOp.getDpsInits ()[0 ]);
872
+ rewriter.replaceUsesWithIf (fillOp.getResult (0 ), fillOp.getDpsInits ()[0 ],
873
+ [&](OpOperand &operand) {
874
+ return isa<LoopLikeOpInterface>(
875
+ operand.getOwner ());
876
+ });
889
877
}
890
878
891
879
rewriter.setInsertionPointAfter (currentOp);
@@ -954,8 +942,8 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
954
942
}
955
943
956
944
// Step 2. Outer loop generation
957
- auto outerLoopResult = outerLoopGeneration (rewriter, linalgOp, cfg,
958
- isa<linalg::FillOp>(fillOp));
945
+ auto outerLoopResult = outerLoopGeneration (
946
+ rewriter, linalgOp, cfg, fillOp && isa<linalg::FillOp>(fillOp));
959
947
if (failed (outerLoopResult)) {
960
948
return failure ();
961
949
}
0 commit comments