Skip to content

Commit ccd02f2

Browse files
committed
Use tensor.expand_shape/collapse_shape to alter tensor rank
1 parent a205731 commit ccd02f2

File tree

2 files changed

+24
-14
lines changed

2 files changed

+24
-14
lines changed

lib/gc/Transforms/DeepTileContractionNamedOp.cpp

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -275,19 +275,22 @@ static void setStaticSizeForExtractSliceOp(RewriterBase &rewriter,
275275
SmallVector<OpFoldResult> mixedOffsets = extractSlice.getMixedOffsets();
276276
SmallVector<OpFoldResult> mixedSizes = extractSlice.getMixedSizes();
277277
SmallVector<OpFoldResult> mixedStrides = extractSlice.getMixedStrides();
278+
auto targetTensor = mlir::RankedTensorType::get(
279+
SmallVector<int64_t>(size.begin() + shrinDimNum, size.end()),
280+
extractSlice.getResult().getType().getElementType());
278281
for (auto &&[i, s] : llvm::enumerate(size))
279282
mixedSizes[i] = getAsIndexOpFoldResult(rewriter.getContext(), s);
280-
if (shrinDimNum > 0)
281-
rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
282-
extractSlice,
283-
mlir::RankedTensorType::get(
284-
SmallVector<int64_t>(size.begin() + shrinDimNum, size.end()),
285-
extractSlice.getResult().getType().getElementType()),
286-
extractSlice.getSource(), mixedOffsets, mixedSizes, mixedStrides);
287-
else
288-
rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
289-
extractSlice, extractSlice.getSource(), mixedOffsets, mixedSizes,
290-
mixedStrides);
283+
Operation *newExtractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
284+
extractSlice->getLoc(), extractSlice.getSource(), mixedOffsets,
285+
mixedSizes, mixedStrides);
286+
if (shrinDimNum > 0) {
287+
rewriter.setInsertionPointAfter(newExtractSliceOp);
288+
Value viewResult = tensorViewRankedTensor(
289+
rewriter, targetTensor, newExtractSliceOp->getResult(0));
290+
rewriter.replaceOp(extractSlice, viewResult);
291+
} else {
292+
rewriter.replaceOp(extractSlice, newExtractSliceOp);
293+
}
291294
}
292295
}
293296

@@ -304,9 +307,12 @@ static void setStaticSizeForInsertSliceOp(RewriterBase &rewriter, Operation *op,
304307
SmallVector<OpFoldResult> mixedStrides = insertSlice.getMixedStrides();
305308
for (auto &&[i, s] : llvm::enumerate(size))
306309
mixedSizes[i] = getAsIndexOpFoldResult(rewriter.getContext(), s);
310+
auto targetTensor = mlir::RankedTensorType::get(
311+
size, insertSlice.getDest().getType().getElementType());
312+
Value viewResult = tensorViewRankedTensor(rewriter, targetTensor, source);
307313
rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
308-
insertSlice, source, insertSlice.getDest(), mixedOffsets, mixedSizes,
309-
mixedStrides);
314+
insertSlice, viewResult, insertSlice.getDest(), mixedOffsets,
315+
mixedSizes, mixedStrides);
310316
}
311317
}
312318

test/mlir/test/gc/Transforms/deepTileContractionNamedOp.mlir

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,13 @@ func.func @matmul_4Dx4D_bf16(%arg0: tensor<128x128x32x32xbf16>, %arg1: tensor<12
5555
// CHECK: tensor.extract_slice {{.*}} [1, 8, 32, 32] [1, 1, 1, 1]
5656
// CHECK: tensor.extract_slice {{.*}} [1, 8, 32, 32] [1, 1, 1, 1]
5757
// CHECK: scf.for
58-
// CHECK: tensor.extract_slice {{.*}} [1, 8, 32, 32] [1, 1, 1, 1]
58+
// CHECK: tensor.collapse_shape {{.*}} tensor<1x8x32x32xbf16> into tensor<8x32x32xbf16>
5959
// CHECK: tensor.extract_slice {{.*}} [1, 8, 16, 32, 2] [1, 1, 1, 1, 1]
60+
// CHECK: tensor.collapse_shape {{.*}} tensor<1x8x16x32x2xbf16> into tensor<8x16x32x2xbf16>
6061
// CHECK: tensor.extract_slice {{.*}} [1, 1, 32, 32] [1, 1, 1, 1]
62+
// CHECK: tensor.collapse_shape {{.*}} tensor<1x1x32x32xf32> into tensor<32x32xf32>
6163
// CHECK: tensor.extract_slice {{.*}} [1, 1, 32, 32] [1, 1, 1, 1]
64+
// CHECK: tensor.collapse_shape {{.*}} tensor<1x1x32x32xbf16> into tensor<32x32xbf16>
6265
// CHECK: scf.if
6366
// CHECK: linalg.fill
6467
// CHECK: linalgx.batch_reduce_matmul_vnni
@@ -92,6 +95,7 @@ func.func @matmul_2Dx4D_bf16(%arg0: tensor<4096x4096xbf16>, %arg1: tensor<128x12
9295
// CHECK: tensor.extract_slice {{.*}} [32, 256] [1, 1]
9396
// CHECK: scf.for
9497
// CHECK: tensor.extract_slice {{.*}} [1, 8, 16, 32, 2] [1, 1, 1, 1, 1]
98+
// CHECK: tensor.collapse_shape {{.*}} tensor<1x8x16x32x2xbf16> into tensor<8x16x32x2xbf16>
9599
// CHECK: tensor.extract_slice {{.*}} [32, 32] [1, 1]
96100
// CHECK: linalg.transpose {{.*}} permutation = [1, 0, 2]
97101
// CHECK: scf.if

0 commit comments

Comments
 (0)