From 54bdb091385c7dc1a05f2db93953d713a5c6044f Mon Sep 17 00:00:00 2001 From: ZhangYan Date: Tue, 2 Jul 2024 19:34:04 -0700 Subject: [PATCH 1/4] modify --- scripts/generate_single_matmul_mlir.py | 18 ++++++++++++++++++ scripts/run_all.sh | 21 ++++++++++++--------- 2 files changed, 30 insertions(+), 9 deletions(-) diff --git a/scripts/generate_single_matmul_mlir.py b/scripts/generate_single_matmul_mlir.py index 2e0f1df71..ac5357ee7 100644 --- a/scripts/generate_single_matmul_mlir.py +++ b/scripts/generate_single_matmul_mlir.py @@ -31,6 +31,22 @@ def generate_single_matmul_mlir(M, N, K): ''' return mlir_code +def generate_single_matmul_mlir_wo_data(M, N, K): + mat_A = numpy.random.rand(M, K) + mat_B = numpy.random.rand(K, N) + mat_C = numpy.dot(mat_A, mat_B) + block_start = "{" + block_end = "}" + mlir_code = f''' +func.func @main_entry(%arg0: tensor<{M}x{K}xf32>, %arg1: tensor<{K}x{N}xf32> ) -> tensor<{M}x{N}xf32> attributes {block_start}llvm.emit_c_interface{block_end} {block_start} + %cst_0 = arith.constant 0.000000e+00 : f32 + %0 = tensor.empty() : tensor<{M}x{N}xf32> + %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<{M}x{N}xf32>) -> tensor<{M}x{N}xf32> + %2 = linalg.matmul ins(%arg0, %arg1 : tensor<{M}x{K}xf32>, tensor<{K}x{N}xf32>) outs(%1 : tensor<{M}x{N}xf32>) -> tensor<{M}x{N}xf32> + return %2 : tensor<{M}x{N}xf32> +{block_end} + ''' + return mlir_code def generate_mlir_bf16_2dx4d(M, N, K, tile_m = 32, tile_n = 32, tile_k = 32, dtype_size=2): M_block = (M-1) // tile_m + 1 @@ -123,6 +139,8 @@ def generate_mlir_f32_4dx4d_generic(M, N, K, tile_m = 32, tile_n = 32, tile_k = args = parser.parse_args() if args.mode == "correctness": code = generate_single_matmul_mlir(args.M, args.N, args.K) + elif args.mode == "f32_2dx2d": + code = generate_single_matmul_mlir_wo_data(args.M, args.N, args.K) elif args.mode == "bf16_2dx4d": code = generate_mlir_bf16_2dx4d(args.M, args.N, args.K, args.tile_m, args.tile_n, args.tile_k) elif args.mode == "bf16_4dx4d": diff --git a/scripts/run_all.sh b/scripts/run_all.sh index 8b3f9416d..6216ef85e 100644 --- a/scripts/run_all.sh +++ b/scripts/run_all.sh @@ -1,15 +1,18 @@ set -ex export PYTHONPATH=`pwd`/python_packages/tpp_core -export LD_PRELOAD=/home/zhicong/miniforge3/lib/libiomp5.so -export MLIR_RUNNER_UTILS=/home/zhicong/code/llvm-project/llvm-install/lib/libmlir_runner_utils.so -export MLIR_C_RUNNER_UTILS=/home/zhicong/code/llvm-project/llvm-install/lib/libmlir_runner_utils.so +export LD_PRELOAD=/home/zhangyan/miniforge3/envs/gc/lib/libiomp5.so +export MLIR_RUNNER_UTILS=/home/zhangyan/graph_compiler_v2/externals/llvm-project/llvm-install/lib/libmlir_runner_utils.so +export MLIR_C_RUNNER_UTILS=/home/zhangyan/graph_compiler_v2/externals/llvm-project/llvm-install/lib/libmlir_runner_utils.so +BUILD_DIR=${PROJECT_DIR}/build + export L1_CACHE_SIZE=49152 export L2_CACHE_SZIE=2097152 -export L3_CACHE_SIZE=1966080 -PROJECT_DIR=/home/zhicong/code/gc-pipeline -BUILD_DIR=${PROJECT_DIR}/build +export L3_CACHE_SIZE=335544320 +export PROJECT_DIR=/home/zhangyan/graph_compiler_v2 + export PYTHONPATH=${PROJECT_DIR}/build/python_packages/gc_mlir_core -export LD_PRELOAD="/home/zhicong/miniforge3/lib/libiomp5.so ${PROJECT_DIR}/build/lib/libGCCpuRuntime.so" +export LD_PRELOAD=$LD_PRELOAD:"/home/zhangyan/miniforge3/envs/gc/lib/libiomp5.so" +export LD_PRELOAD=$LD_PRELOAD:"/home/zhangyan/graph_compiler_v2/build/lib/libGCCpuRuntime.so" export MLIR_RUNNER_UTILS=${PROJECT_DIR}/externals/llvm-project/build/lib/libmlir_runner_utils.so export MLIR_C_RUNNER_UTILS=${PROJECT_DIR}/externals/llvm-project/build/lib/libmlir_c_runner_utils.so @@ -20,14 +23,14 @@ cd $BUILD_DIR echo "thread, dtype, bs, hidden_size, tile, time(ms), GFlops, extra, cmd" for tile in 32 64 128 do -for thread in 1 32 56 +for thread in 32 do for mode in f32_4dx4d_generic bf16_4dx4d do for hidden_size in 4096x4096 4096x11008 11008x4096 4096x32000 do -for bs in 1 16 32 64 512 +for bs in 32 do export OMP_NUM_THREADS=$thread M_SIZE=$bs From fa078708cc05bab3555549829bea52670ffdee1a Mon Sep 17 00:00:00 2001 From: ZhangYan Date: Wed, 3 Jul 2024 01:48:00 -0700 Subject: [PATCH 2/4] can run spliit on K --- lib/gc/Transforms/DeepTileContractionNamedOp.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/gc/Transforms/DeepTileContractionNamedOp.cpp b/lib/gc/Transforms/DeepTileContractionNamedOp.cpp index a13f9c29f..ee234c4d9 100644 --- a/lib/gc/Transforms/DeepTileContractionNamedOp.cpp +++ b/lib/gc/Transforms/DeepTileContractionNamedOp.cpp @@ -906,8 +906,8 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { llvm::isa(linalgOp) || llvm::isa(linalgOp) || llvm::isa(linalgOp) || - llvm::isa(linalgOp) || - llvm::isa(linalgOp); + llvm::isa(linalgOp); // || + // llvm::isa(linalgOp); } LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp, From 5f56b96f0b68f13d77742be4af1eb6ed6d30e41d Mon Sep 17 00:00:00 2001 From: ZhangYan Date: Mon, 8 Jul 2024 08:43:13 -0700 Subject: [PATCH 3/4] update deeptile --- include/gc/Analysis/MatmulConfigAnalysis.h | 24 +++++- lib/gc/Analysis/MatmulConfigAnalysis.cpp | 6 ++ .../Transforms/DeepTileContractionNamedOp.cpp | 7 +- lib/gc/Transforms/Tiling.cpp | 74 ++++++++++++++++--- 4 files changed, 99 insertions(+), 12 deletions(-) diff --git a/include/gc/Analysis/MatmulConfigAnalysis.h b/include/gc/Analysis/MatmulConfigAnalysis.h index 7b91f7908..2ac7de048 100644 --- a/include/gc/Analysis/MatmulConfigAnalysis.h +++ b/include/gc/Analysis/MatmulConfigAnalysis.h @@ -28,11 +28,29 @@ struct SystemDesc { // get runtime OMP_NUM_THREADS uint32_t getNumThreads() { char *numThreads = getenv("OMP_NUM_THREADS"); - if (numThreads) { + if (!threads_limited && numThreads) { return std::stoi(numThreads); } + return curThreads; + } + + // set the expected threads + void limitOnSingleNode(uint32_t numa_node) { + char *cacheSize = getenv("NUMA_THREADS"); + if (cacheSize) { + curThreads = std::stoi(cacheSize); + threads_limited = true; + } + } + + uint32_t getNumNodes() { + char *numThreads = getenv("OMP_NUM_THREADS"); + if (threads_limited && numThreads) { + return std::stoi(numThreads) / curThreads; + } return 1; } + // get cache size by cacheLevel size_t getCacheSize(uint8_t cacheLevel) { if (cacheLevel == 1) { @@ -57,6 +75,10 @@ struct SystemDesc { SmallVector getContractionOperationMaxVectorLength() { return {512UL, 512UL}; } + +private: + uint32_t curThreads = 1; + bool threads_limited = false; }; struct MatmulConfig { diff --git a/lib/gc/Analysis/MatmulConfigAnalysis.cpp b/lib/gc/Analysis/MatmulConfigAnalysis.cpp index 3440cd3ec..a2132e5bf 100644 --- a/lib/gc/Analysis/MatmulConfigAnalysis.cpp +++ b/lib/gc/Analysis/MatmulConfigAnalysis.cpp @@ -345,6 +345,12 @@ previous matmul MatmulConfigAnalysis::MatmulConfigAnalysis(Operation *root) { SystemDesc sysDesc; if (auto linalgOp = dyn_cast(root)) { + // Check if the operation has an attribute named 'splited' + auto splitedAttr = linalgOp->getAttrOfType("splited"); + if (splitedAttr) { + sysDesc.limitOnSingleNode(splitedAttr.getInt()); + llvm::outs() << "splited mm, and should be allocated on numa node 0.\n"; + } auto oprandDimType = *getOprandDimType(linalgOp); // get the origin M,N,K size auto MDimTypeIdx = extractDimTypeIdx(oprandDimType[0], DimType::M); diff --git a/lib/gc/Transforms/DeepTileContractionNamedOp.cpp b/lib/gc/Transforms/DeepTileContractionNamedOp.cpp index ee234c4d9..b2a444abb 100644 --- a/lib/gc/Transforms/DeepTileContractionNamedOp.cpp +++ b/lib/gc/Transforms/DeepTileContractionNamedOp.cpp @@ -471,6 +471,7 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp, else tileSizes[d] = getAsIndexOpFoldResult(b.getContext(), tile); } + SmallVector loopRanges = cast(currentOp.getOperation()).getIterationDomain(b); OpBuilder::InsertionGuard guard(b); @@ -482,7 +483,6 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp, tileSizes[idx] = loopRanges[idx].size; } } - SmallVector newParallelDims; for (auto i = 0UL; i < reductionDims.size(); i++) { newParallelDims.push_back(getAsIndexOpFoldResult(b.getContext(), i)); @@ -595,6 +595,11 @@ struct deepTileMatmul : public OpInterfaceRewritePattern { auto NOuterBlockSize = NDimPos.size() > 1 ? (cfg.NBlock - 1) / cfg.innerMostNBlock + 1 : cfg.NBlock; + // Outermost Numa loop + option.nestedTileSizes.emplace_back( + SmallVector{uint32_t(MFirstDim / 2)}); + option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForallOp); + option.loopDim.emplace_back(SmallVector{MDimPos[0]}); // Outer option.nestedTileSizes.emplace_back(SmallVector{ MParallelBlockSize, NParallelBlockSize, KParallelBlockSize}); diff --git a/lib/gc/Transforms/Tiling.cpp b/lib/gc/Transforms/Tiling.cpp index ea4d73722..471cd130c 100644 --- a/lib/gc/Transforms/Tiling.cpp +++ b/lib/gc/Transforms/Tiling.cpp @@ -782,6 +782,22 @@ FailureOr static tileLinalgOpImpl( return tileLinalgOpImpl(b, op, tileSizeVector, options); } +FailureOr +getTiledImplementationOnNuma(Operation *op, OpBuilder &b, + ArrayRef offsets, + ArrayRef sizes) { + // Leave the `sizeBounds` value empty. That is only needed when the `sizes` + // specified could lead to out of bounds accesses. + Location loc = op->getLoc(); + LinalgOp linalgOp = cast(op); + SmallVector valuesToTile = linalgOp->getOperands(); + + SmallVector resultTensorTypes = + getTensorOutputTypes(linalgOp, valuesToTile); + Operation *tiledOp = clone(b, linalgOp, resultTensorTypes, valuesToTile); + return TilingResult{{tiledOp}, SmallVector(tiledOp->getResults())}; +} + FailureOr tileAllUsingForall( RewriterBase &b, PartialReductionOpInterface op, ArrayRef threadNums, ArrayRef tileSizes, @@ -964,6 +980,16 @@ FailureOr tileAllUsingForall( // 4.b. Clone the op and update init operands. // We cannot use a IRMapping here because it can replace // different OpOperands with the same value. + bool isNumaLoop = false; + if (tileSizes.size() == iterationDomain.size()) { + for (auto [idx, tile] : llvm::enumerate(tileSizes)) { + if (idx == 0 && tileSizes[idx] == iterationDomain[idx].size) + break; + if (idx > 0 && tileSizes[idx] != iterationDomain[idx].size) + break; + isNumaLoop = true; + } + } Operation *clonedOp = b.clone(*op.getOperation()); b.modifyOpInPlace(clonedOp, [&]() { for (auto [initOperandPtr, tiledInitValue] : llvm::zip_equal( @@ -974,17 +1000,32 @@ FailureOr tileAllUsingForall( }); // 5. Tile the cloned op and delete the clone. if (tileSizes.empty() || threadNums.empty()) { - FailureOr tilingResult = - cast(clonedOp).getTiledImplementation( - b, tiledOffsets, tiledSizes); - if (failed(tilingResult)) - return clonedOp->emitError("Failed to tile op: "); - if (tilingResult->tiledOps.size() != 1) { - return clonedOp->emitError("expected a single produced tiled op, got ") - << tilingResult->tiledOps.size(); + if (!isNumaLoop) { + FailureOr tilingResult = + cast(clonedOp).getTiledImplementation( + b, tiledOffsets, tiledSizes); + if (failed(tilingResult)) + return clonedOp->emitError("Failed to tile op: "); + if (tilingResult->tiledOps.size() != 1) { + return clonedOp->emitError( + "expected a single produced tiled op, got ") + << tilingResult->tiledOps.size(); + } + tiledOp = tilingResult->tiledOps.front(); + tilingResults = tilingResult->tiledValues; + } else { + FailureOr tilingResult = getTiledImplementationOnNuma( + cast(clonedOp), b, tiledOffsets, tiledSizes); + if (failed(tilingResult)) + return clonedOp->emitError("Failed to tile op: "); + if (tilingResult->tiledOps.size() != 1) { + return clonedOp->emitError( + "expected a single produced tiled op, got ") + << tilingResult->tiledOps.size(); + } + tiledOp = tilingResult->tiledOps.front(); + tilingResults = tilingResult->tiledValues; } - tiledOp = tilingResult->tiledOps.front(); - tilingResults = tilingResult->tiledValues; } else { LinalgTilingOptions options; FailureOr maybeTiled = tileLinalgOpImpl( @@ -1039,6 +1080,19 @@ FailureOr tileAllUsingForall( nonZeroDimIdx++; } } + if (auto attr = resultSizesRank[0].dyn_cast()) { + if (auto intAttr = attr.dyn_cast()) { + if (intAttr.getInt() == 16) + resultSizesRank[0] = b.getIndexAttr(32); + } + } else if (auto value = resultSizesRank[0].dyn_cast()) { + if (auto constantOp = value.getDefiningOp()) { + if (auto intAttr = constantOp.getValue().dyn_cast()) { + if (intAttr.getInt() == 16) + resultSizesRank[0] = b.getIndexAttr(32); + } + } + } if (hasReductionThreads) { for (auto [parallelDims, redVar] : llvm::zip(constantNewParallelDims, reductionInductionVars)) { From 2f29a49b56ed0e1956daffb9b74b81dd1cf2b663 Mon Sep 17 00:00:00 2001 From: ZhangYan Date: Tue, 9 Jul 2024 01:11:42 -0700 Subject: [PATCH 4/4] stash --- .../Transforms/DeepTileContractionNamedOp.cpp | 53 +++++++++++++++++++ lib/gc/Transforms/Tiling.cpp | 3 +- 2 files changed, 55 insertions(+), 1 deletion(-) diff --git a/lib/gc/Transforms/DeepTileContractionNamedOp.cpp b/lib/gc/Transforms/DeepTileContractionNamedOp.cpp index b2a444abb..e8e8572c5 100644 --- a/lib/gc/Transforms/DeepTileContractionNamedOp.cpp +++ b/lib/gc/Transforms/DeepTileContractionNamedOp.cpp @@ -471,7 +471,23 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp, else tileSizes[d] = getAsIndexOpFoldResult(b.getContext(), tile); } + llvm::outs() << "====================================\n"; + llvm::outs() << "tileSize: "; + for (auto t : tileSizes) { + llvm::outs() << t << ", "; + } + llvm::outs() << "\n"; + bool isLastForAllLoop = false; + for (auto [idx, tile] : llvm::enumerate(tileSizes)) { + if (isConstantIntValue(tile, 0)) { + break; + } + if (idx == tileSizes.size() - 1) + isLastForAllLoop = true; + } + llvm::outs() << "isLastForAllLoop: " << isLastForAllLoop << "\n"; + llvm::outs() << "====================================\n"; SmallVector loopRanges = cast(currentOp.getOperation()).getIterationDomain(b); OpBuilder::InsertionGuard guard(b); @@ -503,6 +519,43 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp, } } } + if (isLastForAllLoop) { + b.setInsertionPointAfter(currentOp); + mlir::easybuild::EasyBuilder eb{b, currentOp.getLoc()}; + auto cond = eb(true); + auto forAllOp = tilingResult->loops; + auto ifOp = b.create(currentOp.getLoc(), cond); + b.setInsertionPointToStart(&ifOp.getThenRegion().front()); + b.setInsertionPointAfter(ifOp); + // auto loc = currentOp.getLoc(); + // auto indexType = b.getIndexType(); + // auto c1 = b.create(loc, 0); + + // Get the argument to compare with + // Value arg2 = forAllOp.getRegion().getArgument( + // 0); // This assumes %arg2 is the first argument + // Value comparison = + // b.create(loc, arith::CmpIPredicate::eq, arg2, + // c1); + + // // Create the scf.if operation + // b.setInsertionPointToStart(&forAllOp.getRegion().front()); + // auto ifOp = b.create(loc, comparison, + // /*withElseRegion=*/false); + + // // Move the body of forallOp into the if true region + // b.inlineRegionBefore(forAllOp.getRegion(), ifOp.getThenRegion(), + // ifOp.getThenRegion().begin()); + + // // Now the body of forallOp is in the ifOp, we should clean up the + // // original region. + // forAllOp.getRegion().dropAllReferences(); + // // forAllOp.getRegion().clear(); + + // // Insert a yield operation to the scf.if operation's then region + // b.setInsertionPointToEnd(&ifOp.getThenRegion().front()); + // b.create(loc); + } } else if (auto tilingInterface = cast(currentOp.getOperation())) { auto tilingResult = linalg::tileToForallOpUsingTileSizes( diff --git a/lib/gc/Transforms/Tiling.cpp b/lib/gc/Transforms/Tiling.cpp index 471cd130c..9dd040fd5 100644 --- a/lib/gc/Transforms/Tiling.cpp +++ b/lib/gc/Transforms/Tiling.cpp @@ -987,7 +987,8 @@ FailureOr tileAllUsingForall( break; if (idx > 0 && tileSizes[idx] != iterationDomain[idx].size) break; - isNumaLoop = true; + if (idx == tileSizes.size() - 1) + isNumaLoop = true; } } Operation *clonedOp = b.clone(*op.getOperation());