Skip to content

Commit 9706d1f

Browse files
committed
fix deepTileMatmul
1 parent b8f89e9 commit 9706d1f

File tree

1 file changed

+29
-41
lines changed

1 file changed

+29
-41
lines changed

lib/gc/Transforms/DeepTileContractionNamedOp.cpp

Lines changed: 29 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ static Operation *findParentFillOp(Value val) {
272272
!isa<linalg::FillOp>(currentOp)) {
273273
currentOp = currentOp->getResult(0).getDefiningOp();
274274
}
275-
if (isa<linalg::FillOp>(currentOp)) {
275+
if (currentOp && isa<linalg::FillOp>(currentOp)) {
276276
return currentOp;
277277
}
278278

@@ -322,11 +322,10 @@ static unsigned getOprandDim(linalg::LinalgOp &linalgOp, unsigned iteratorPos,
322322
return linalgOp.getShape(linalgOp.getDpsInputOperand(operandIdx))[dimPos];
323323
}
324324

325-
static LogicalResult setStaticSizeForExtractSliceOp(RewriterBase &rewriter,
326-
Operation *op,
327-
bool isExtract,
328-
SmallVector<int64_t> size,
329-
int shrinDimNum = 0) {
325+
static void setStaticSizeForExtractSliceOp(RewriterBase &rewriter,
326+
Operation *op, bool isExtract,
327+
SmallVector<int64_t> size,
328+
int shrinDimNum = 0) {
330329
OpBuilder::InsertionGuard guard(rewriter);
331330
rewriter.setInsertionPoint(op);
332331
if (auto extractSlice = dyn_cast<tensor::ExtractSliceOp>(op)) {
@@ -348,15 +347,12 @@ static LogicalResult setStaticSizeForExtractSliceOp(RewriterBase &rewriter,
348347
extractSlice, extractSlice.getSource(), mixedOffsets, mixedSizes,
349348
mixedStrides);
350349
}
351-
} else {
352-
return failure();
353350
}
354-
return mlir::success();
355351
}
356352

357-
static LogicalResult setStaticSizeForInsertSliceOp(RewriterBase &rewriter,
358-
Operation *op, Value source,
359-
SmallVector<int64_t> size) {
353+
static void setStaticSizeForInsertSliceOp(RewriterBase &rewriter, Operation *op,
354+
Value source,
355+
SmallVector<int64_t> size) {
360356
OpBuilder::InsertionGuard guard(rewriter);
361357
rewriter.setInsertionPoint(op);
362358
if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(op)) {
@@ -369,10 +365,7 @@ static LogicalResult setStaticSizeForInsertSliceOp(RewriterBase &rewriter,
369365
rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
370366
insertSlice, source, insertSlice.getDest(), mixedOffsets, mixedSizes,
371367
mixedStrides);
372-
} else {
373-
return failure();
374368
}
375-
return success();
376369
}
377370

378371
using InnermostFullResultCallBackFn = std::function<FailureOr<linalg::LinalgOp>(
@@ -691,7 +684,6 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
691684
linalg::LinalgOp originOp,
692685
linalg::LinalgOp currentOp,
693686
innerBodyGenerationOption &option) const {
694-
695687
mlir::easybuild::EasyBuilder eb{rewriter, originOp.getLoc()};
696688
auto operandDimTypes = getOprandDimType(originOp);
697689
auto cfg = MatmulConfigAnalysis(originOp.getOperation()).getConfig();
@@ -744,6 +736,7 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
744736
CInnermostDims =
745737
SmallVector<int64_t>{cfg.innerMostMBlock, cfg.innerMostNBlock};
746738
}
739+
747740
if (NDimNum > 1) {
748741
firstN = true;
749742
firstK = true;
@@ -780,21 +773,17 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
780773

781774
// update the extractSlice to static size, replace it with
782775
// useBlockedLayout when
783-
if (failed(setStaticSizeForExtractSliceOp(
784-
rewriter, currentOp.getDpsInits()[0].getDefiningOp(), true,
785-
CInnermostDims, MDimNum > 1 ? 2 : 0)) ||
786-
failed(setStaticSizeForExtractSliceOp(
787-
rewriter, currentOp.getDpsInputs()[1].getDefiningOp(), true,
788-
BInnermostDims, NDimNum > 1)) ||
789-
failed(setStaticSizeForExtractSliceOp(
790-
rewriter, currentOp.getDpsInputs()[0].getDefiningOp(), true,
791-
AInnermostDims, MDimNum > 1)) ||
792-
(currentOp.getDpsInits().size() > 1 &&
793-
failed(setStaticSizeForExtractSliceOp(
794-
rewriter, currentOp.getDpsInits()[1].getDefiningOp(), true,
795-
CInnermostDims, MDimNum > 1 ? 2 : 0)))) {
796-
return failure();
776+
setStaticSizeForExtractSliceOp(rewriter,
777+
currentOp.getDpsInputs()[1].getDefiningOp(),
778+
true, BInnermostDims, NDimNum > 1);
779+
setStaticSizeForExtractSliceOp(rewriter,
780+
currentOp.getDpsInputs()[0].getDefiningOp(),
781+
true, AInnermostDims, MDimNum > 1);
782+
for (auto init : currentOp.getDpsInits()) {
783+
setStaticSizeForExtractSliceOp(rewriter, init.getDefiningOp(), true,
784+
CInnermostDims, MDimNum > 1 ? 2 : 0);
797785
}
786+
798787
// View the tensor to brgemm required format
799788
Value dataOprand = tensorViewRankedTensor(
800789
rewriter,
@@ -841,10 +830,7 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
841830

842831
// Insert the result back to the original tensor
843832
for (Operation *user : currentOp->getResult(0).getUsers()) {
844-
if (failed(setStaticSizeForInsertSliceOp(rewriter, user, result,
845-
CInnermostDims))) {
846-
return failure();
847-
}
833+
setStaticSizeForInsertSliceOp(rewriter, user, result, CInnermostDims);
848834
}
849835

850836
if (option.needLowPrecisionCast) {
@@ -869,10 +855,8 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
869855
auto ifOp = eb.getLastOperaion();
870856
// set static size for the insertSliceOp of copyOp
871857
for (Operation *user : currentOp->getResult(1).getUsers()) {
872-
if (failed(setStaticSizeForInsertSliceOp(
873-
rewriter, user, ifOp->getResult(0), CInnermostDims))) {
874-
return failure();
875-
}
858+
setStaticSizeForInsertSliceOp(rewriter, user, ifOp->getResult(0),
859+
CInnermostDims);
876860
}
877861
rewriter.replaceOp(currentOp, {matmul->getResult(0), ifOp->getResult(0)});
878862
} else {
@@ -885,7 +869,11 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
885869
if (cfg.KThreads <= 1) {
886870
// if use k slicing, the fill op is still need to be kept for the reduce
887871
// init
888-
rewriter.replaceOp(fillOp, fillOp.getDpsInits()[0]);
872+
rewriter.replaceUsesWithIf(fillOp.getResult(0), fillOp.getDpsInits()[0],
873+
[&](OpOperand &operand) {
874+
return isa<LoopLikeOpInterface>(
875+
operand.getOwner());
876+
});
889877
}
890878

891879
rewriter.setInsertionPointAfter(currentOp);
@@ -954,8 +942,8 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
954942
}
955943

956944
// Step 2. Outer loop generation
957-
auto outerLoopResult = outerLoopGeneration(rewriter, linalgOp, cfg,
958-
isa<linalg::FillOp>(fillOp));
945+
auto outerLoopResult = outerLoopGeneration(
946+
rewriter, linalgOp, cfg, fillOp && isa<linalg::FillOp>(fillOp));
959947
if (failed(outerLoopResult)) {
960948
return failure();
961949
}

0 commit comments

Comments
 (0)