@@ -275,19 +275,22 @@ static void setStaticSizeForExtractSliceOp(RewriterBase &rewriter,
275
275
SmallVector<OpFoldResult> mixedOffsets = extractSlice.getMixedOffsets ();
276
276
SmallVector<OpFoldResult> mixedSizes = extractSlice.getMixedSizes ();
277
277
SmallVector<OpFoldResult> mixedStrides = extractSlice.getMixedStrides ();
278
+ auto targetTensor = mlir::RankedTensorType::get (
279
+ SmallVector<int64_t >(size.begin () + shrinDimNum, size.end ()),
280
+ extractSlice.getResult ().getType ().getElementType ());
278
281
for (auto &&[i, s] : llvm::enumerate (size))
279
282
mixedSizes[i] = getAsIndexOpFoldResult (rewriter.getContext (), s);
280
- if (shrinDimNum > 0 )
281
- rewriter. replaceOpWithNewOp <tensor::ExtractSliceOp>(
282
- extractSlice,
283
- mlir::RankedTensorType::get (
284
- SmallVector< int64_t >(size. begin () + shrinDimNum, size. end ()),
285
- extractSlice. getResult (). getType (). getElementType ()),
286
- extractSlice. getSource (), mixedOffsets, mixedSizes, mixedStrides );
287
- else
288
- rewriter. replaceOpWithNewOp <tensor::ExtractSliceOp>(
289
- extractSlice, extractSlice. getSource (), mixedOffsets, mixedSizes,
290
- mixedStrides);
283
+ Operation *newExtractSliceOp = rewriter. create <tensor::ExtractSliceOp>(
284
+ extractSlice-> getLoc (), extractSlice. getSource (), mixedOffsets,
285
+ mixedSizes, mixedStrides);
286
+ if (shrinDimNum > 0 ) {
287
+ rewriter. setInsertionPointAfter (newExtractSliceOp);
288
+ Value viewResult = tensorViewRankedTensor (
289
+ rewriter, targetTensor, newExtractSliceOp-> getResult ( 0 ) );
290
+ rewriter. replaceOp (extractSlice, viewResult);
291
+ } else {
292
+ rewriter. replaceOp (extractSlice, newExtractSliceOp);
293
+ }
291
294
}
292
295
}
293
296
@@ -304,9 +307,12 @@ static void setStaticSizeForInsertSliceOp(RewriterBase &rewriter, Operation *op,
304
307
SmallVector<OpFoldResult> mixedStrides = insertSlice.getMixedStrides ();
305
308
for (auto &&[i, s] : llvm::enumerate (size))
306
309
mixedSizes[i] = getAsIndexOpFoldResult (rewriter.getContext (), s);
310
+ auto targetTensor = mlir::RankedTensorType::get (
311
+ size, insertSlice.getDest ().getType ().getElementType ());
312
+ Value viewResult = tensorViewRankedTensor (rewriter, targetTensor, source);
307
313
rewriter.replaceOpWithNewOp <tensor::InsertSliceOp>(
308
- insertSlice, source , insertSlice.getDest (), mixedOffsets, mixedSizes ,
309
- mixedStrides);
314
+ insertSlice, viewResult , insertSlice.getDest (), mixedOffsets,
315
+ mixedSizes, mixedStrides);
310
316
}
311
317
}
312
318
0 commit comments