@@ -782,6 +782,22 @@ FailureOr<TiledLinalgOp> static tileLinalgOpImpl(
782
782
return tileLinalgOpImpl<LoopTy>(b, op, tileSizeVector, options);
783
783
}
784
784
785
+ FailureOr<TilingResult>
786
+ getTiledImplementationOnNuma (Operation *op, OpBuilder &b,
787
+ ArrayRef<OpFoldResult> offsets,
788
+ ArrayRef<OpFoldResult> sizes) {
789
+ // Leave the `sizeBounds` value empty. That is only needed when the `sizes`
790
+ // specified could lead to out of bounds accesses.
791
+ Location loc = op->getLoc ();
792
+ LinalgOp linalgOp = cast<LinalgOp>(op);
793
+ SmallVector<Value> valuesToTile = linalgOp->getOperands ();
794
+
795
+ SmallVector<Type> resultTensorTypes =
796
+ getTensorOutputTypes (linalgOp, valuesToTile);
797
+ Operation *tiledOp = clone (b, linalgOp, resultTensorTypes, valuesToTile);
798
+ return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults ())};
799
+ }
800
+
785
801
FailureOr<linalg::ForallReductionTilingResult> tileAllUsingForall (
786
802
RewriterBase &b, PartialReductionOpInterface op,
787
803
ArrayRef<OpFoldResult> threadNums, ArrayRef<OpFoldResult> tileSizes,
@@ -964,6 +980,16 @@ FailureOr<linalg::ForallReductionTilingResult> tileAllUsingForall(
964
980
// 4.b. Clone the op and update init operands.
965
981
// We cannot use a IRMapping here because it can replace
966
982
// different OpOperands with the same value.
983
+ bool isNumaLoop = false ;
984
+ if (tileSizes.size () == iterationDomain.size ()) {
985
+ for (auto [idx, tile] : llvm::enumerate (tileSizes)) {
986
+ if (idx == 0 && tileSizes[idx] == iterationDomain[idx].size )
987
+ break ;
988
+ if (idx > 0 && tileSizes[idx] != iterationDomain[idx].size )
989
+ break ;
990
+ isNumaLoop = true ;
991
+ }
992
+ }
967
993
Operation *clonedOp = b.clone (*op.getOperation ());
968
994
b.modifyOpInPlace (clonedOp, [&]() {
969
995
for (auto [initOperandPtr, tiledInitValue] : llvm::zip_equal (
@@ -974,17 +1000,32 @@ FailureOr<linalg::ForallReductionTilingResult> tileAllUsingForall(
974
1000
});
975
1001
// 5. Tile the cloned op and delete the clone.
976
1002
if (tileSizes.empty () || threadNums.empty ()) {
977
- FailureOr<TilingResult> tilingResult =
978
- cast<TilingInterface>(clonedOp).getTiledImplementation (
979
- b, tiledOffsets, tiledSizes);
980
- if (failed (tilingResult))
981
- return clonedOp->emitError (" Failed to tile op: " );
982
- if (tilingResult->tiledOps .size () != 1 ) {
983
- return clonedOp->emitError (" expected a single produced tiled op, got " )
984
- << tilingResult->tiledOps .size ();
1003
+ if (!isNumaLoop) {
1004
+ FailureOr<TilingResult> tilingResult =
1005
+ cast<TilingInterface>(clonedOp).getTiledImplementation (
1006
+ b, tiledOffsets, tiledSizes);
1007
+ if (failed (tilingResult))
1008
+ return clonedOp->emitError (" Failed to tile op: " );
1009
+ if (tilingResult->tiledOps .size () != 1 ) {
1010
+ return clonedOp->emitError (
1011
+ " expected a single produced tiled op, got " )
1012
+ << tilingResult->tiledOps .size ();
1013
+ }
1014
+ tiledOp = tilingResult->tiledOps .front ();
1015
+ tilingResults = tilingResult->tiledValues ;
1016
+ } else {
1017
+ FailureOr<TilingResult> tilingResult = getTiledImplementationOnNuma (
1018
+ cast<TilingInterface>(clonedOp), b, tiledOffsets, tiledSizes);
1019
+ if (failed (tilingResult))
1020
+ return clonedOp->emitError (" Failed to tile op: " );
1021
+ if (tilingResult->tiledOps .size () != 1 ) {
1022
+ return clonedOp->emitError (
1023
+ " expected a single produced tiled op, got " )
1024
+ << tilingResult->tiledOps .size ();
1025
+ }
1026
+ tiledOp = tilingResult->tiledOps .front ();
1027
+ tilingResults = tilingResult->tiledValues ;
985
1028
}
986
- tiledOp = tilingResult->tiledOps .front ();
987
- tilingResults = tilingResult->tiledValues ;
988
1029
} else {
989
1030
LinalgTilingOptions options;
990
1031
FailureOr<TiledLinalgOp> maybeTiled = tileLinalgOpImpl<scf::ForOp>(
@@ -1039,6 +1080,19 @@ FailureOr<linalg::ForallReductionTilingResult> tileAllUsingForall(
1039
1080
nonZeroDimIdx++;
1040
1081
}
1041
1082
}
1083
+ if (auto attr = resultSizesRank[0 ].dyn_cast <Attribute>()) {
1084
+ if (auto intAttr = attr.dyn_cast <IntegerAttr>()) {
1085
+ if (intAttr.getInt () == 16 )
1086
+ resultSizesRank[0 ] = b.getIndexAttr (32 );
1087
+ }
1088
+ } else if (auto value = resultSizesRank[0 ].dyn_cast <Value>()) {
1089
+ if (auto constantOp = value.getDefiningOp <arith::ConstantOp>()) {
1090
+ if (auto intAttr = constantOp.getValue ().dyn_cast <IntegerAttr>()) {
1091
+ if (intAttr.getInt () == 16 )
1092
+ resultSizesRank[0 ] = b.getIndexAttr (32 );
1093
+ }
1094
+ }
1095
+ }
1042
1096
if (hasReductionThreads) {
1043
1097
for (auto [parallelDims, redVar] :
1044
1098
llvm::zip (constantNewParallelDims, reductionInductionVars)) {
0 commit comments