
Commit 5f56b96

Author: ZhangYan (committed)
update deeptile
1 parent: fa07870

File tree: 4 files changed, +99 -12 lines changed

include/gc/Analysis/MatmulConfigAnalysis.h
Lines changed: 23 additions & 1 deletion

@@ -28,11 +28,29 @@ struct SystemDesc {
   // get runtime OMP_NUM_THREADS
   uint32_t getNumThreads() {
     char *numThreads = getenv("OMP_NUM_THREADS");
-    if (numThreads) {
+    if (!threads_limited && numThreads) {
       return std::stoi(numThreads);
     }
+    return curThreads;
+  }
+
+  // set the expected threads
+  void limitOnSingleNode(uint32_t numa_node) {
+    char *cacheSize = getenv("NUMA_THREADS");
+    if (cacheSize) {
+      curThreads = std::stoi(cacheSize);
+      threads_limited = true;
+    }
+  }
+
+  uint32_t getNumNodes() {
+    char *numThreads = getenv("OMP_NUM_THREADS");
+    if (threads_limited && numThreads) {
+      return std::stoi(numThreads) / curThreads;
+    }
     return 1;
   }
+
   // get cache size by cacheLevel
   size_t getCacheSize(uint8_t cacheLevel) {
     if (cacheLevel == 1) {
@@ -57,6 +75,10 @@ struct SystemDesc {
   SmallVector<size_t> getContractionOperationMaxVectorLength() {
     return {512UL, 512UL};
   }
+
+private:
+  uint32_t curThreads = 1;
+  bool threads_limited = false;
 };

 struct MatmulConfig {
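
With these changes SystemDesc carries a little state: OMP_NUM_THREADS is still the total thread budget, NUMA_THREADS (read in limitOnSingleNode, although the variable holding it is named cacheSize) caps the threads used once an op is pinned to a single node, and getNumNodes derives the node count as total threads divided by per-node threads. A minimal standalone mock of that arithmetic, with the environment values assumed purely for illustration:

// Simplified mock of the patched SystemDesc thread accounting; the env values
// below (56 total threads, 28 per NUMA node) are assumed, not from the commit.
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <string>

struct MockSystemDesc {
  uint32_t getNumThreads() {
    char *numThreads = getenv("OMP_NUM_THREADS");
    if (!threads_limited && numThreads)
      return std::stoi(numThreads);
    return curThreads;
  }
  void limitOnSingleNode(uint32_t /*numa_node*/) {
    char *numaThreads = getenv("NUMA_THREADS"); // named cacheSize in the patch
    if (numaThreads) {
      curThreads = std::stoi(numaThreads);
      threads_limited = true;
    }
  }
  uint32_t getNumNodes() {
    char *numThreads = getenv("OMP_NUM_THREADS");
    if (threads_limited && numThreads)
      return std::stoi(numThreads) / curThreads;
    return 1;
  }

private:
  uint32_t curThreads = 1;
  bool threads_limited = false;
};

int main() {
  setenv("OMP_NUM_THREADS", "56", 1); // assumed machine-wide budget
  setenv("NUMA_THREADS", "28", 1);    // assumed threads per NUMA node
  MockSystemDesc desc;
  std::printf("before: threads=%u nodes=%u\n", desc.getNumThreads(),
              desc.getNumNodes()); // 56 threads, 1 node
  desc.limitOnSingleNode(/*numa_node=*/0);
  std::printf("after:  threads=%u nodes=%u\n", desc.getNumThreads(),
              desc.getNumNodes()); // 28 threads, 2 nodes
}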

lib/gc/Analysis/MatmulConfigAnalysis.cpp
Lines changed: 6 additions & 0 deletions

@@ -345,6 +345,12 @@ previous matmul
 MatmulConfigAnalysis::MatmulConfigAnalysis(Operation *root) {
   SystemDesc sysDesc;
   if (auto linalgOp = dyn_cast<linalg::LinalgOp>(root)) {
+    // Check if the operation has an attribute named 'splited'
+    auto splitedAttr = linalgOp->getAttrOfType<IntegerAttr>("splited");
+    if (splitedAttr) {
+      sysDesc.limitOnSingleNode(splitedAttr.getInt());
+      llvm::outs() << "splited mm, and should be allocated on numa node 0.\n";
+    }
     auto oprandDimType = *getOprandDimType(linalgOp);
     // get the origin M,N,K size
     auto MDimTypeIdx = extractDimTypeIdx(oprandDimType[0], DimType::M);
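
The constructor only consumes the splited attribute; the commit does not show which pass attaches it. A hypothetical producer-side helper, sketched with standard MLIR attribute APIs, for tagging a matmul so it takes this path (the helper and the choice of an i32 attribute are assumptions; only the attribute name "splited" comes from the patch):

// Hypothetical helper, not part of the commit: mark a linalg op so that
// MatmulConfigAnalysis calls SystemDesc::limitOnSingleNode for it.
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/Builders.h"

static void markAsSplitMatmul(mlir::linalg::LinalgOp op, int32_t numaNode) {
  mlir::OpBuilder builder(op->getContext());
  // Read back above via getAttrOfType<IntegerAttr>("splited").getInt().
  op->setAttr("splited", builder.getI32IntegerAttr(numaNode));
}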

lib/gc/Transforms/DeepTileContractionNamedOp.cpp
Lines changed: 6 additions & 1 deletion

@@ -471,6 +471,7 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp,
       else
         tileSizes[d] = getAsIndexOpFoldResult(b.getContext(), tile);
     }
+
     SmallVector<Range> loopRanges =
         cast<TilingInterface>(currentOp.getOperation()).getIterationDomain(b);
     OpBuilder::InsertionGuard guard(b);
@@ -482,7 +483,6 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp,
         tileSizes[idx] = loopRanges[idx].size;
       }
     }
-
     SmallVector<OpFoldResult> newParallelDims;
     for (auto i = 0UL; i < reductionDims.size(); i++) {
       newParallelDims.push_back(getAsIndexOpFoldResult(b.getContext(), i));
@@ -595,6 +595,11 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
     auto NOuterBlockSize = NDimPos.size() > 1
                                ? (cfg.NBlock - 1) / cfg.innerMostNBlock + 1
                                : cfg.NBlock;
+    // Outermost Numa loop
+    option.nestedTileSizes.emplace_back(
+        SmallVector<size_t>{uint32_t(MFirstDim / 2)});
+    option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForallOp);
+    option.loopDim.emplace_back(SmallVector<size_t>{MDimPos[0]});
     // Outer
     option.nestedTileSizes.emplace_back(SmallVector<size_t>{
         MParallelBlockSize, NParallelBlockSize, KParallelBlockSize});
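
The added entries prepend one more level in front of the existing "Outer" level: a forall loop over the first M dimension with tile size MFirstDim / 2, so the M iterations are split into two halves, matching the two NUMA nodes implied by the hard-coded divisor. A small worked example of how the levels stack, with every size assumed purely for illustration:

// Worked example of the added outermost tiling level (all values assumed).
#include <cstdio>

int main() {
  unsigned MFirstDim = 64;         // assumed outer M iteration count
  unsigned MParallelBlockSize = 8; // assumed existing "Outer" tile on M

  unsigned numaTile = MFirstDim / 2;          // 32, as in MFirstDim / 2 above
  unsigned numaChunks = MFirstDim / numaTile; // 2 chunks, one per NUMA node
  unsigned outerChunksPerNode = numaTile / MParallelBlockSize; // 4

  std::printf("NUMA chunks: %u, outer M chunks per node: %u\n", numaChunks,
              outerChunksPerNode);
}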

lib/gc/Transforms/Tiling.cpp
Lines changed: 64 additions & 10 deletions

@@ -782,6 +782,22 @@ FailureOr<TiledLinalgOp> static tileLinalgOpImpl(
   return tileLinalgOpImpl<LoopTy>(b, op, tileSizeVector, options);
 }

+FailureOr<TilingResult>
+getTiledImplementationOnNuma(Operation *op, OpBuilder &b,
+                             ArrayRef<OpFoldResult> offsets,
+                             ArrayRef<OpFoldResult> sizes) {
+  // Leave the `sizeBounds` value empty. That is only needed when the `sizes`
+  // specified could lead to out of bounds accesses.
+  Location loc = op->getLoc();
+  LinalgOp linalgOp = cast<LinalgOp>(op);
+  SmallVector<Value> valuesToTile = linalgOp->getOperands();
+
+  SmallVector<Type> resultTensorTypes =
+      getTensorOutputTypes(linalgOp, valuesToTile);
+  Operation *tiledOp = clone(b, linalgOp, resultTensorTypes, valuesToTile);
+  return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+}
+
 FailureOr<linalg::ForallReductionTilingResult> tileAllUsingForall(
     RewriterBase &b, PartialReductionOpInterface op,
     ArrayRef<OpFoldResult> threadNums, ArrayRef<OpFoldResult> tileSizes,
@@ -964,6 +980,16 @@ FailureOr<linalg::ForallReductionTilingResult> tileAllUsingForall(
     // 4.b. Clone the op and update init operands.
     // We cannot use a IRMapping here because it can replace
     // different OpOperands with the same value.
+    bool isNumaLoop = false;
+    if (tileSizes.size() == iterationDomain.size()) {
+      for (auto [idx, tile] : llvm::enumerate(tileSizes)) {
+        if (idx == 0 && tileSizes[idx] == iterationDomain[idx].size)
+          break;
+        if (idx > 0 && tileSizes[idx] != iterationDomain[idx].size)
+          break;
+        isNumaLoop = true;
+      }
+    }
     Operation *clonedOp = b.clone(*op.getOperation());
     b.modifyOpInPlace(clonedOp, [&]() {
       for (auto [initOperandPtr, tiledInitValue] : llvm::zip_equal(
@@ -974,17 +1000,32 @@ FailureOr<linalg::ForallReductionTilingResult> tileAllUsingForall(
     });
     // 5. Tile the cloned op and delete the clone.
     if (tileSizes.empty() || threadNums.empty()) {
-      FailureOr<TilingResult> tilingResult =
-          cast<TilingInterface>(clonedOp).getTiledImplementation(
-              b, tiledOffsets, tiledSizes);
-      if (failed(tilingResult))
-        return clonedOp->emitError("Failed to tile op: ");
-      if (tilingResult->tiledOps.size() != 1) {
-        return clonedOp->emitError("expected a single produced tiled op, got ")
-               << tilingResult->tiledOps.size();
+      if (!isNumaLoop) {
+        FailureOr<TilingResult> tilingResult =
+            cast<TilingInterface>(clonedOp).getTiledImplementation(
+                b, tiledOffsets, tiledSizes);
+        if (failed(tilingResult))
+          return clonedOp->emitError("Failed to tile op: ");
+        if (tilingResult->tiledOps.size() != 1) {
+          return clonedOp->emitError(
+                     "expected a single produced tiled op, got ")
+                 << tilingResult->tiledOps.size();
+        }
+        tiledOp = tilingResult->tiledOps.front();
+        tilingResults = tilingResult->tiledValues;
+      } else {
+        FailureOr<TilingResult> tilingResult = getTiledImplementationOnNuma(
+            cast<TilingInterface>(clonedOp), b, tiledOffsets, tiledSizes);
+        if (failed(tilingResult))
+          return clonedOp->emitError("Failed to tile op: ");
+        if (tilingResult->tiledOps.size() != 1) {
+          return clonedOp->emitError(
+                     "expected a single produced tiled op, got ")
+                 << tilingResult->tiledOps.size();
+        }
+        tiledOp = tilingResult->tiledOps.front();
+        tilingResults = tilingResult->tiledValues;
       }
-      tiledOp = tilingResult->tiledOps.front();
-      tilingResults = tilingResult->tiledValues;
     } else {
       LinalgTilingOptions options;
       FailureOr<TiledLinalgOp> maybeTiled = tileLinalgOpImpl<scf::ForOp>(
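
The isNumaLoop flag computed in the hunk above decides whether the cloned op is tiled by extracting slices (getTiledImplementation) or simply re-created on the full operands (getTiledImplementationOnNuma). A standalone mock of the predicate with the OpFoldResult comparisons replaced by plain integers and the sample domains assumed; note that once dimension 0 is seen to be partially tiled the flag stays set, because the later breaks do not clear it:

#include <cstdint>
#include <cstdio>
#include <vector>

// Mirrors the isNumaLoop check in tileAllUsingForall, with OpFoldResult /
// Range replaced by plain integers for illustration.
static bool isNumaLoop(const std::vector<int64_t> &tileSizes,
                       const std::vector<int64_t> &iterationDomain) {
  bool numa = false;
  if (tileSizes.size() == iterationDomain.size()) {
    for (size_t idx = 0; idx < tileSizes.size(); ++idx) {
      if (idx == 0 && tileSizes[idx] == iterationDomain[idx])
        break; // dim 0 not tiled at all: never a NUMA loop
      if (idx > 0 && tileSizes[idx] != iterationDomain[idx])
        break; // a later dim is tiled: stop scanning (flag already set)
      numa = true; // set as soon as dim 0 is partially tiled
    }
  }
  return numa;
}

int main() {
  // Assumed 4096x4096x4096 matmul domain, M tiled in half: detected.
  std::printf("%d\n", isNumaLoop({2048, 4096, 4096}, {4096, 4096, 4096})); // 1
  // Dim 0 left untiled: not detected.
  std::printf("%d\n", isNumaLoop({4096, 4096, 4096}, {4096, 4096, 4096})); // 0
}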
@@ -1039,6 +1080,19 @@ FailureOr<linalg::ForallReductionTilingResult> tileAllUsingForall(
         nonZeroDimIdx++;
       }
     }
+    if (auto attr = resultSizesRank[0].dyn_cast<Attribute>()) {
+      if (auto intAttr = attr.dyn_cast<IntegerAttr>()) {
+        if (intAttr.getInt() == 16)
+          resultSizesRank[0] = b.getIndexAttr(32);
+      }
+    } else if (auto value = resultSizesRank[0].dyn_cast<Value>()) {
+      if (auto constantOp = value.getDefiningOp<arith::ConstantOp>()) {
+        if (auto intAttr = constantOp.getValue().dyn_cast<IntegerAttr>()) {
+          if (intAttr.getInt() == 16)
+            resultSizesRank[0] = b.getIndexAttr(32);
+        }
+      }
+    }
     if (hasReductionThreads) {
       for (auto [parallelDims, redVar] :
            llvm::zip(constantNewParallelDims, reductionInductionVars)) {
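
This last hunk bumps a constant leading result size of 16 up to 32, unwrapping the OpFoldResult by hand for both the Attribute case and the arith.constant-defined Value case. MLIR's getConstantIntValue helper performs, to my understanding, the same unwrapping; a sketch of an equivalent, more compact form (the helper function below is mine, and the 16 and 32 are simply the constants hard-coded in the patch):

#include <cstdint>
#include <optional>

#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"

// Sketch only: bump a statically-known size of 16 to 32, as the patch does for
// resultSizesRank[0]. getConstantIntValue already looks through both an
// IntegerAttr and an integer arith.constant.
static mlir::OpFoldResult bumpSixteenToThirtyTwo(mlir::OpBuilder &b,
                                                 mlir::OpFoldResult size) {
  if (std::optional<int64_t> cst = mlir::getConstantIntValue(size))
    if (*cst == 16)
      return b.getIndexAttr(32);
  return size;
}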
