From d29f03840516fb756683f64151462420ad86e6af Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Mon, 27 May 2024 16:15:48 +0800 Subject: [PATCH 01/66] init --- include/gc/Transforms/Passes.td | 30 + lib/gc/Transforms/CMakeLists.txt | 2 + lib/gc/Transforms/CPUPhysicalResigterPass.cpp | 961 ++++++++++++++++++ lib/gc/Transforms/LowerTileVectorPass.cpp | 489 +++++++++ .../gc/transforms/cpu-vetor-distribution.mlir | 30 + test/gc/transforms/linalg-vectorization.mlir | 99 ++ 6 files changed, 1611 insertions(+) create mode 100644 lib/gc/Transforms/CPUPhysicalResigterPass.cpp create mode 100644 lib/gc/Transforms/LowerTileVectorPass.cpp create mode 100644 test/gc/transforms/cpu-vetor-distribution.mlir create mode 100644 test/gc/transforms/linalg-vectorization.mlir diff --git a/include/gc/Transforms/Passes.td b/include/gc/Transforms/Passes.td index d31baa5a7..24533be99 100644 --- a/include/gc/Transforms/Passes.td +++ b/include/gc/Transforms/Passes.td @@ -31,4 +31,34 @@ def ConvertOneDNNGraphToLinalg : Pass<"convert-onednn-graph-to-linalg"> { ]; } +def LowerToTileVector : Pass<"lower-to-tile-vector"> { + let summary = "Lower tensor to tile vector."; + let description = [{ + Lower tensor to tile vector form. + }]; + let dependentDialects = [ + "::mlir::func::FuncDialect", + "::mlir::math::MathDialect", + "::mlir::arith::ArithDialect", + "::mlir::tensor::TensorDialect", + "::mlir::linalg::LinalgDialect", + "::mlir::vector::VectorDialect", + ]; +} + +def CPUPhysicalRegisterPass : Pass<"CPU-physical-register-pass", "func::FuncOp"> { + let summary = "Lower operation to cpu pysical register size."; + let description = [{ + Physical register size lowering pass. + }]; + let dependentDialects = [ + "::mlir::func::FuncDialect", + "::mlir::math::MathDialect", + "::mlir::arith::ArithDialect", + "::mlir::tensor::TensorDialect", + "::mlir::vector::VectorDialect", + "::mlir::scf::SCFDialect", + ]; +} + #endif // GC_DIALECT_GC_PASSES diff --git a/lib/gc/Transforms/CMakeLists.txt b/lib/gc/Transforms/CMakeLists.txt index d25c8a027..d9fc8b600 100644 --- a/lib/gc/Transforms/CMakeLists.txt +++ b/lib/gc/Transforms/CMakeLists.txt @@ -7,6 +7,8 @@ gc_set_mlir_link_components(MLIR_LINK_COMPONENTS add_mlir_library(GCPasses OneDNNGraphToLinalg.cpp TileNamed.cpp + LowerTileVectorPass.cpp + CPUPhysicalResigterPass.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include diff --git a/lib/gc/Transforms/CPUPhysicalResigterPass.cpp b/lib/gc/Transforms/CPUPhysicalResigterPass.cpp new file mode 100644 index 000000000..120d12c3e --- /dev/null +++ b/lib/gc/Transforms/CPUPhysicalResigterPass.cpp @@ -0,0 +1,961 @@ +//===- CPUPhysicalResigterPass.cpp.cpp - OneDNNGraph To Linalg +// Lowering -*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tensor/Transforms/Transforms.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/Support/Casting.h" +#include +#include +#include +#include + +namespace mlir { +namespace gc { +#define GEN_PASS_DEF_CPUPHYSICALREGISTERPASS +#include "gc/Transforms/Passes.h.inc" +namespace { +#define DEBUG_TYPE "lower-to-physical-register-pass" + +#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +struct HardwareInfo { + bool favx512f = true; + bool favx2 = true; +} HW; + +bool isSpecialOp(Operation *op) { + return llvm::isa(op) || + llvm::isa(op) || + llvm::isa(op) || + llvm::isa(op) || + llvm::isa(op); +} + +bool is_innermost_operation(Operation *op) { + bool inner_most = true; + op->walk([&inner_most](Operation *p) { + if (llvm::isa(p)) { + inner_most = false; + return WalkResult::interrupt(); + } + }); + return inner_most; +} + +int generateValidSteps(int steps, VectorType type) { + return type.getShape().back() >= steps ? steps : 1; +} + +// Get the maximum number of current data types that a register can hold +[[nodiscard]] int getDataTypeMAXSIMDLength(VectorType type) { + auto typebits = type.getElementTypeBitWidth(); + const int favx512bits = 512; + const int favx2bits = 256; + if (HW.favx512f) { + return generateValidSteps(favx512bits / typebits, type); + } else if (HW.favx2) { + return generateValidSteps(favx2bits / typebits, type); + } else { + // invalid + LDBG("Please check the hardware information."); + assert(false && "Invalid hardware."); + return -1; + } +} + +mlir::FailureOr getOperationVectorType(Operation *op) { + return TypeSwitch>(op) + .Case( + [&](vector::TransferWriteOp transferWriteOp) + -> mlir::FailureOr { + return transferWriteOp.getVectorType(); + }) + .Case([&](vector::TransferReadOp transferReadOp) + -> mlir::FailureOr { + return transferReadOp.getVectorType(); + }) + .Case( + [&](arith::ConstantOp constantOp) { return failure(); }) + .Default([&](Operation *op) -> mlir::FailureOr { + if (!op->getResults().empty()) { + auto t = op->getResultTypes().front().dyn_cast(); + if (t) { + return t; + } + } + return failure(); + }); +} + +// Filter out the operations that can be vectorized. We are only interested in +// operations that do not contain any for loops(innermost IR). +[[nodiscard]] bool filterOperation(Operation *op) { + if (!is_innermost_operation(op)) { + LDBG("Operation is not innermost" << *op << "\n"); + return false; + } + + // We are only interested about the operation in vector dialect + if (failed(getOperationVectorType(op))) { + LDBG("Operation is not in vector dialect" << *op << "\n"); + return false; + } + return true; +} + +// Since we rewrote transfer_read and transfer_write, the `permutationmap` must +// be changed. 
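+// For example, assuming a rank-3 destination tensor, a transfer_write that
+// originally used the identity map (d0, d1, d2) -> (d0, d1, d2) is rewritten
+// to the minor-dimension map (d0, d1, d2) -> (d2) with in_bounds = [true],
+// matching the 1-D physical-register vector produced by
+// rewriteOperationAsVectorize; the outer dimensions are addressed through the
+// loop induction variables substituted later in checkAndSetOperand.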
+void setOpVectorizationPermutationMap(Operation *op, IRRewriter &rewriter, + RankedTensorType tensorType) { + SmallVector affineExprs; + affineExprs.push_back(rewriter.getAffineDimExpr(tensorType.getRank() - 1)); + auto destAffineMap = AffineMap::get(tensorType.getRank(), 0, affineExprs, + rewriter.getContext()); + SmallVector inBounds(1, true); + if (mlir::isa(op)) { + auto transferWriteOp = mlir::dyn_cast(op); + transferWriteOp.setPermutationMap(destAffineMap); + transferWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds)); + } else if (mlir::isa(op)) { + auto transferReadOp = mlir::dyn_cast(op); + transferReadOp.setPermutationMap(destAffineMap); + transferReadOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds)); + } +} + +// scf.for yield helper function +void maybeYieldValue(OpBuilder &b, Location loc, ValueRange value) { + bool hasRetVal = !value.empty(); + if (hasRetVal) { + assert(!value.empty() && "Expected non-empty value"); + b.create(loc, value); + } else { + b.create(loc); + } +} + +// +void checkAndSetOperand( + Operation *op, const ValueRange &iterArgs, + const llvm::DenseMap &operandIdxMap, + const llvm::SmallVector &inductionVars, + const llvm::DenseMap &opPermuationMap) { + for (auto [idx, opd] : llvm::enumerate(op->getOperands())) { + if (operandIdxMap.contains(opd)) { + op->setOperand(idx, iterArgs[operandIdxMap.at(opd)]); + } + } + int offset = isa(op) ? 2 : 1; + if (llvm::dyn_cast(op) || + llvm::dyn_cast(op)) { + assert(opPermuationMap.contains(op)); + auto permutationMap = opPermuationMap.at(op); + + auto dimExpr = permutationMap.getResults(); + for (auto [idx, x] : llvm::enumerate(dimExpr)) { + if (mlir::dyn_cast(x)) { + auto dim = mlir::dyn_cast(x).getPosition(); + op->setOperand(dim + offset, inductionVars[dim]); + } + } + } +} + +// TODO: need to rewrite reduce operation as a performance forms like +// graph-compiler v1 +scf::ForOp constructReductionNestedForOp( + OpBuilder &b, const Location &loc, const ValueRange &iterArgs, + const VectorType &type, const llvm::ArrayRef &dims, size_t idx, + std::queue &queue, const llvm::SetVector &resultSet, + llvm::SmallVector &inductionVars, + const llvm::DenseMap &operandIdxMap, + const llvm::SmallVector &rdDims, + const llvm::DenseMap &opPermuationMap) { + const int loop_step = getDataTypeMAXSIMDLength(type); + + // loop initialization variable + auto zero = + b.create(b.getUnknownLoc(), b.getIndexType(), + b.getIntegerAttr(b.getIndexType(), 0)); + auto forSteps = b.create( + b.getUnknownLoc(), b.getIndexType(), + b.getIntegerAttr(b.getIndexType(), + idx == dims.size() - 1 ? loop_step : 1)); + auto numIter = b.create( + b.getUnknownLoc(), b.getIndexType(), + b.getIntegerAttr(b.getIndexType(), dims[idx])); + + // Create a loop and move vectorized operation into loops. 
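+  // For example, assuming an AVX-512 target and a <4x8x16xf32> group shape
+  // (loop_step = 512 / 32 = 16 for f32), the recursion emits a nest of the
+  // form
+  //   scf.for %i = %c0 to %c4 step %c1 {
+  //     scf.for %j = %c0 to %c8 step %c1 {
+  //       scf.for %k = %c0 to %c16 step %c16 {
+  //         ... vectorized body operating on vector<16xf32> ...
+  //       }}}
+  // with the group's destination values threaded through the iter_args.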
+ auto forOp = b.create( + b.getUnknownLoc(), zero, numIter, forSteps, iterArgs, + [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) { + inductionVars.emplace_back(iv); + + // inner most body of the loop + if (idx == dims.size() - 1) { + Operation *lastOperation = queue.front(); + while (!queue.empty()) { + auto x = queue.front(); + queue.pop(); + if (lastOperation == x) { + x->moveBefore(b.getBlock(), b.getBlock()->begin()); + } else { + x->moveAfter(lastOperation); + lastOperation = x; + } + // check operation type to set correct operand + checkAndSetOperand(x, loopState, operandIdxMap, inductionVars, + opPermuationMap); + } + maybeYieldValue(b, loc, resultSet.getArrayRef()); + } else { + + // outter loop + auto nxtFor = constructReductionNestedForOp( + b, loc, loopState, type, dims, idx + 1, queue, resultSet, + inductionVars, operandIdxMap, rdDims, opPermuationMap); + maybeYieldValue(b, loc, nxtFor->getResults()); + } + }); + return forOp; +} + +scf::ForOp constructNestedForOp( + OpBuilder &b, const Location &loc, const ValueRange &iterArgs, + const VectorType &type, const llvm::ArrayRef &dims, size_t idx, + std::queue &queue, const llvm::SetVector &resultSet, + llvm::SmallVector &inductionVars, + const llvm::DenseMap &operandIdxMap, + const llvm::DenseMap &opPermuationMap) { + const int loop_step = getDataTypeMAXSIMDLength(type); + + // loop initialization variable + auto zero = + b.create(b.getUnknownLoc(), b.getIndexType(), + b.getIntegerAttr(b.getIndexType(), 0)); + auto forSteps = b.create( + b.getUnknownLoc(), b.getIndexType(), + b.getIntegerAttr(b.getIndexType(), + idx == dims.size() - 1 ? loop_step : 1)); + auto numIter = b.create( + b.getUnknownLoc(), b.getIndexType(), + b.getIntegerAttr(b.getIndexType(), dims[idx])); + + // Create a loop and move vectorized operation into loops. + auto forOp = b.create( + b.getUnknownLoc(), zero, numIter, forSteps, iterArgs, + [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) { + inductionVars.emplace_back(iv); + + // inner most body of the loop + if (idx == dims.size() - 1) { + Operation *lastOperation = queue.front(); + while (!queue.empty()) { + auto x = queue.front(); + queue.pop(); + if (lastOperation == x) { + x->moveBefore(b.getBlock(), b.getBlock()->begin()); + } else { + x->moveAfter(lastOperation); + lastOperation = x; + } + // check operation type to set correct operand + checkAndSetOperand(x, loopState, operandIdxMap, inductionVars, + opPermuationMap); + } + maybeYieldValue(b, loc, resultSet.getArrayRef()); + } else { + + // outter loop + auto nxtFor = constructNestedForOp( + b, loc, loopState, type, dims, idx + 1, queue, resultSet, + inductionVars, operandIdxMap, opPermuationMap); + maybeYieldValue(b, loc, nxtFor->getResults()); + } + }); + return forOp; +} + +bool isCompatibleVectorType(Operation *op1, Operation *op2) { + auto type1 = getOperationVectorType(op1); + auto type2 = getOperationVectorType(op2); + if (failed(type1) || failed(type2)) { + return false; + } + auto sp1 = type1.value(); + auto sp2 = type2.value(); + auto min_rank = std::min(sp1.getRank(), sp2.getRank()) - 1; + for (auto i = min_rank; i >= 0; i--) { + if (sp1.getDimSize(i) != sp2.getDimSize(i)) { + return false; + } + } + + return true; +} + +bool isNeedNewGroup(llvm::SmallVector, 8> &opGroups, + Operation *op) { + // 1. 
check previous operation + if (!opGroups.back().empty()) { + auto prevOp = opGroups.back().back(); + // not in the same operation + if (prevOp->getParentOp() != op->getParentOp()) { + return true; + } + // previous operation is a special operation + if (isSpecialOp(prevOp)) { + return true; + } + // previous operation vector type is not compatible with current operation + if (!isCompatibleVectorType(prevOp, op)) { + return true; + } + } + + // 2. check current operation + if (isSpecialOp(op)) { + return true; + } + return false; +} + +void addOperationToGroup( + llvm::SmallVector, 8> &opGroups, + llvm::DenseMap &opGroupIndexMap, Operation *op, + llvm::SmallVector &groupsShapes) { + // + if (isNeedNewGroup(opGroups, op)) { + opGroups.emplace_back(std::queue()); + } + if (opGroups.size() != groupsShapes.size()) { + groupsShapes.emplace_back(getOperationVectorType(op).value()); + } + opGroups.back().push(op); + opGroupIndexMap[op] = opGroups.size() - 1; +} + +// We classify the operations we are interested in after filtering. Operations +// of in the same group have no data dependencies. Those operations can generate +// a same outter for loop. +void classifyOperations(func::FuncOp func, + llvm::SmallVector, 8> &opGroups, + llvm::DenseMap &opGroupIndexMap, + llvm::SmallVector &groupsShapes) { + func->walk([&](Operation *op) { + TypeSwitch(op).Default([&](Operation *op) { + if (filterOperation(op)) { + addOperationToGroup(opGroups, opGroupIndexMap, op, groupsShapes); + } + }); + }); +} + +Value setOutGroupOperationOperandResult(Operation *op, + const VectorType &newOperandType) { + auto ret = TypeSwitch(op) + .Case([&](arith::ConstantOp constantOp) { + IRRewriter rewriter(op); + rewriter.setInsertionPointAfter(op); + Type resultElementType = newOperandType.getElementType(); + + Attribute initValueAttr; + if (isa(resultElementType)) + initValueAttr = FloatAttr::get(resultElementType, 0.0); + else + initValueAttr = IntegerAttr::get(resultElementType, 0); + auto cntOp = rewriter.create( + rewriter.getUnknownLoc(), + DenseElementsAttr::get(newOperandType, {initValueAttr})); + return cntOp->getResults()[0]; + }) + .Default([&](Operation *op) { return Value(); }); + return ret; +} + +void setOperationOperandResult( + Operation *op, const VectorType &newOperandType, + const llvm::DenseMap &opMap) { + for (auto [idx, x] : llvm::enumerate(op->getOperands())) { + if (x.getType().dyn_cast()) { + if (!opMap.contains(x.getDefiningOp())) { + auto result = setOutGroupOperationOperandResult(x.getDefiningOp(), + newOperandType); + op->setOperand(idx, result); + } else { + x.setType(newOperandType); + } + } + } + for (auto x : op->getResults()) { + if (x.getType().dyn_cast()) { + x.setType(newOperandType); + } + } +}; + +/// Rewrite the operations in the group to vectorized form. +void rewriteOperationAsVectorize( + const std::queue &groupOps, + llvm::DenseMap &opMap, IRRewriter &rewriter, + llvm::DenseMap &opPermuationMap) { + std::queue transformQueue(groupOps); + + auto getVectorzedType = [](Operation *op) -> VectorType { + // Check that the operation type can be broken + // down into a loop. 
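+    // The returned type is always a 1-D vector whose length is the physical
+    // SIMD width for the element type, e.g. vector<16xf32> on AVX-512 or
+    // vector<8xf32> on AVX2 (see getDataTypeMAXSIMDLength).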
+ auto baseType = getOperationVectorType(op); + if (failed(baseType)) { + LDBG("Failed to get vector type for operation: " << *op << "\n"); + assert(false && "Failed to get vector type for operation"); + return VectorType(); + } + auto vectorizedType = baseType.value(); + const int loop_step = getDataTypeMAXSIMDLength(vectorizedType); + return VectorType::get({loop_step}, vectorizedType.getElementType()); + }; + + while (!transformQueue.empty()) { + auto op = transformQueue.front(); + transformQueue.pop(); + auto lowerResult = + TypeSwitch(op) + .Case( + [&](vector::TransferWriteOp transferWriteOp) { + auto newOperandType = getVectorzedType(transferWriteOp); + if (!isSpecialOp( + transferWriteOp->getOperand(0).getDefiningOp())) { + opPermuationMap.insert( + {transferWriteOp, transferWriteOp.getPermutationMap()}); + transferWriteOp->getOperand(0).setType(newOperandType); + setOpVectorizationPermutationMap( + transferWriteOp, rewriter, + transferWriteOp->getResult(0) + .getType() + .dyn_cast()); + } + + return success(); + }) + .Case( + [&](vector::TransferReadOp transferReadOp) { + auto newOperandType = getVectorzedType(transferReadOp); + auto users = transferReadOp->getUsers(); + bool isUserSpecial = false; + for (auto *opUse : users) { + if (isSpecialOp(opUse)) { + isUserSpecial = true; + break; + } + } + if (!isUserSpecial) { + opPermuationMap.insert( + {transferReadOp, transferReadOp.getPermutationMap()}); + transferReadOp->getResult(0).setType(newOperandType); + setOpVectorizationPermutationMap( + transferReadOp, rewriter, + transferReadOp.getSource() + .getType() + .dyn_cast()); + } + + return success(); + }) + .Case( + [&](vector::MultiDimReductionOp multiReductionOp) { + return success(); + }) + .Default([&](Operation *op) { + if (isSpecialOp(op)) { + return success(); + } + setOperationOperandResult(op, getVectorzedType(op), opMap); + return success(); + }); + if (failed(lowerResult)) { + LDBG("Failed to rewrite operation: " << *op << "\n"); + assert(false && "Failed to rewrite operation"); + } + } +} + +// analysis operation' operands are coming from which operation's result +void analysisOperaionOperandSource( + size_t idx, std::queue &grp, + llvm::DenseMap &opGroupIndexMap, + llvm::SmallVector, 8> &groupOperandNeedSet) { + auto tmpOpQueue(grp); + llvm::SetVector opOperands; + while (!tmpOpQueue.empty()) { + auto t = tmpOpQueue.front(); + for (auto x : t->getOperands()) { + // not in the same group + if (opGroupIndexMap.contains(x.getDefiningOp()) && + opGroupIndexMap[x.getDefiningOp()] != idx) { + groupOperandNeedSet[idx].insert(x); + } else { + groupOperandNeedSet[idx].insert(x); + } + } + tmpOpQueue.pop(); + } +} + +Operation *createTensorEmptyBefore(Operation *op) { + auto rtType = op->getResultTypes()[0].dyn_cast(); + IRRewriter reWriter(op); + + SmallVector shapes; + SmallVector dynDims; + for (unsigned i = 0; i < rtType.getRank(); i++) { + shapes.push_back(rtType.getDimSize(i)); + if (rtType.isDynamicDim(i)) + dynDims.push_back(reWriter.create(reWriter.getUnknownLoc(), + op->getResult(0), i)); + } + return reWriter.create(reWriter.getUnknownLoc(), + rtType.getShape(), + rtType.getElementType(), dynDims); +} + +Operation * +createTransferReadOpBefore(Operation *op, const Value &operand, + vector::TransferReadOp *srcReadOp = nullptr) { + auto operandType = operand.getType().dyn_cast(); + + IRRewriter rewriter(op); + auto zero = + rewriter.create(rewriter.getUnknownLoc(), 0); + auto padValue = rewriter.create( + rewriter.getUnknownLoc(), + 
rewriter.getZeroAttr(operandType.getElementType())); + + if (srcReadOp) { + auto resultType = srcReadOp->getType().dyn_cast(); + SmallVector inBoundsVal(resultType.getRank(), true); + auto srcReadOpAffineMap = srcReadOp->getPermutationMap(); + // result of read operation should be same as operand + auto t = rewriter.create( + rewriter.getUnknownLoc(), + /*vectorType=*/ + VectorType::get(resultType.getShape(), resultType.getElementType()), + /*source=*/operand, + /*indices=*/SmallVector(operandType.getRank(), zero), + /**affinemap*/ srcReadOpAffineMap, + /*inBounds=*/inBoundsVal); + + return t; + } else { + SmallVector inBoundsVal(operandType.getRank(), true); + auto t = rewriter.create( + rewriter.getUnknownLoc(), + /*vectorType=*/ + VectorType::get(operandType.getShape(), operandType.getElementType()), + /*source=*/operand, + /*indices=*/SmallVector(operandType.getRank(), zero), + /**affinemap*/ padValue, + /*inBounds=*/inBoundsVal); + return t; + } +} + +Operation *createTransferWriteOpAfter(Operation *op, const Value &dest) { + auto rtType = op->getResultTypes()[0].dyn_cast(); + auto rank = rtType.getRank(); + auto dstType = dest.getType().dyn_cast(); + IRRewriter reWriter(op); + + auto zero = + reWriter.create(reWriter.getUnknownLoc(), 0); + + reWriter.setInsertionPointAfter(op); + SmallVector inBoundsVal(rank, true); + + SmallVector shapes; + SmallVector dynDims; + for (unsigned i = 0; i < rtType.getRank(); i++) { + shapes.push_back(rtType.getDimSize(i)); + if (rtType.isDynamicDim(i)) { + dynDims.push_back(reWriter.create(reWriter.getUnknownLoc(), + op->getResult(0), i)); + } + } + return reWriter.create( + reWriter.getUnknownLoc(), + /*vector=*/op->getResult(0), + /*source=*/dest, + /*indices=*/SmallVector(dstType.getRank(), zero), + /*inBounds=*/inBoundsVal); +} + +// canonicalizing operation as tensor empty and transfer write the operation +// result into the empty tensor +[[nodiscard]] std::pair +canonicalizeSourceOperation(Operation *op) { + auto emtpyOp = createTensorEmptyBefore(op); + auto writeOp = createTransferWriteOpAfter(op, emtpyOp->getResults()[0]); + return std::make_pair(emtpyOp->getResults()[0], writeOp->getResults()[0]); +} + +[[nodiscard]] Value +canonicalizeCurrentOperation(Operation *op, const Value &transferReadOperand, + size_t operandIdx, + vector::TransferReadOp *srcReadOp = nullptr) { + // transfer_read operation + auto readOp = createTransferReadOpBefore(op, transferReadOperand, srcReadOp); + op->setOperand(operandIdx, readOp->getResults()[0]); + return readOp->getResults()[0]; +} + +mlir::FailureOr getOperationDestnationOperand(Operation *op) { + return llvm::TypeSwitch>(op) + .Case( + [&](vector::TransferWriteOp transferWriteOp) { + LDBG(" DPS operation : " << *op << "\n"); + return transferWriteOp->getOperand(1); + }) + .Case([&](vector::TransferReadOp transferReadOp) { + LDBG(" DPS operation : " << *op << "\n"); + return transferReadOp->getOperand(0); + }) + .Default([&](Operation *op) { + LDBG("Try to get not DPS operation inits: " << *op << "\n"); + return failure(); + }); +} + +// analysis operations of current group need which operation's result value +void analysisGroupOperationOperands( + llvm::SmallVector, 8> &opGroups, + llvm::DenseMap &opGroupIndexMap, + llvm::SmallVector, 8> &groupOperandNeedSet) { + + for (auto [idx, grp] : enumerate(opGroups)) { + analysisOperaionOperandSource(idx, grp, opGroupIndexMap, + groupOperandNeedSet); + } +} + +// TODO: need to rewrite reduce +// llvm::SmallVector & +// getReductionDims(vector::MultiDimReductionOp 
&reductionOp, +// llvm::SmallVector &rdDims) { +// auto rdDimsAttr = reductionOp.getReductionDims().getValue(); +// for (auto x : rdDimsAttr) { +// rdDims.emplace_back(x.cast().getInt()); +// } +// return rdDims; +// } + +void updateOpOperandResultInGroups( + llvm::SmallVector, 8> &opGroups, + llvm::DenseMap &opGroupIndexMap, size_t opGid, + Operation *op, Value &init, const Value &result = Value()) { + auto tmpOpQueue(opGroups[opGid]); + std::queue newOpQueue; + while (!tmpOpQueue.empty()) { + auto curOp = tmpOpQueue.front(); + tmpOpQueue.pop(); + if (curOp == op) { + if (!failed(getOperationVectorType(init.getDefiningOp()))) { + newOpQueue.push(init.getDefiningOp()); + opGroupIndexMap[init.getDefiningOp()] = opGid; + } + + newOpQueue.push(op); + + if (result && !failed(getOperationVectorType(result.getDefiningOp()))) { + newOpQueue.push(result.getDefiningOp()); + opGroupIndexMap[result.getDefiningOp()] = opGid; + } + } else { + newOpQueue.push(curOp); + } + } + opGroups[opGid] = newOpQueue; +} + +// analysis operation result of current group whether needed by other +// operation which out of current group +void analysisGroupOperationResults( + func::FuncOp &func, llvm::SmallVector, 8> &opGroups, + IRMapping &mapOpResultToYield, + llvm::DenseMap &opGroupIndexMap, + llvm::SmallVector, 8> &groupResultYeildSet, + llvm::SmallVector, 8> &groupOpDestination) { + llvm::DenseMap> srcOpCanoniclizedMap; + + func.walk([&](Operation *op) { + for (auto [idx, opd] : llvm::enumerate(op->getOperands())) { + auto sourceOp = opd.getDefiningOp(); + if (opGroupIndexMap.contains(sourceOp)) { + auto sourceOpGid = opGroupIndexMap[sourceOp]; + // + bool notInSameGroup = + opGroupIndexMap.contains(op) && sourceOpGid != opGroupIndexMap[op]; + bool outOfGroup = !opGroupIndexMap.contains(op); + if (notInSameGroup or outOfGroup) { + // update init iterargs + auto dstRet = getOperationDestnationOperand(sourceOp); + if (failed(dstRet)) { + if (!srcOpCanoniclizedMap.contains(sourceOp)) { + auto [init, result] = canonicalizeSourceOperation(sourceOp); + srcOpCanoniclizedMap.insert({sourceOp, {init, result}}); + updateOpOperandResultInGroups(opGroups, opGroupIndexMap, + sourceOpGid, sourceOp, init, + result); + groupOpDestination[sourceOpGid].insert(init); + groupResultYeildSet[sourceOpGid].insert(result); + mapOpResultToYield.map(result, result); + } + + auto opInit = canonicalizeCurrentOperation( + op, srcOpCanoniclizedMap[sourceOp].second, idx); + updateOpOperandResultInGroups(opGroups, opGroupIndexMap, + opGroupIndexMap[op], op, opInit); + + } else { + if (mlir::isa(sourceOp)) { + auto transferReadOp = + mlir::dyn_cast(sourceOp); + auto opInit = canonicalizeCurrentOperation(op, dstRet.value(), + idx, &transferReadOp); + updateOpOperandResultInGroups(opGroups, opGroupIndexMap, + opGroupIndexMap[op], op, opInit); + + } else { + groupOpDestination[sourceOpGid].insert(dstRet.value()); + groupResultYeildSet[sourceOpGid].insert(opd); + + // just map to it self, placeholder + mapOpResultToYield.map(opd, opd); + } + } + } + } + } + }); + LDBG("Complete analysis group operation results\n"); +} + +void analysisGroupOperaionOperandsResults( + llvm::SmallVector, 8> &opGroups, + llvm::DenseMap &opGroupIndexMap, + llvm::SmallVector, 8> &groupOperandNeedSet, + func::FuncOp &func, + llvm::SmallVector, 8> &groupResultYeildSet, + IRMapping &mapOpResultToYield, + llvm::SmallVector, 8> &groupOpDestination) { + // Operands + analysisGroupOperationOperands(opGroups, opGroupIndexMap, + groupOperandNeedSet); + // Results + 
analysisGroupOperationResults(func, opGroups, mapOpResultToYield, + opGroupIndexMap, groupResultYeildSet, + groupOpDestination); +} + +mlir::FailureOr generateVectorizedForLoop( + IRRewriter &rewriter, const llvm::SetVector &resultSet, + const llvm::SetVector &dstOperandSet, const VectorType &vectorType, + std::queue &queue, + const llvm::DenseMap &opPermuationMap) { + assert(!resultSet.empty() && "Expected non-empty value"); + // prepare for loop iterargs + llvm::SmallVector operands; + llvm::DenseMap operandIdxMap; + for (auto [idx, x] : llvm::enumerate(dstOperandSet)) { + operands.emplace_back(x); + operandIdxMap[x] = operands.size() - 1; + } + ValueRange iterArgs(operands); + auto shapes = vectorType.getShape(); + llvm::SmallVector inductionVars; + // TODO: special operation process + bool isOpSpecial = false; + std::queue tmpQ(queue); + // temporary for special operation generation + while (!tmpQ.empty()) { + if (isSpecialOp(tmpQ.front())) { + isOpSpecial = true; + break; + } + tmpQ.pop(); + } + if (isOpSpecial) { + return failure(); + } + // generate for loop + auto forOp = constructNestedForOp( + rewriter, rewriter.getUnknownLoc(), iterArgs, vectorType, shapes, 0, + queue, resultSet, inductionVars, operandIdxMap, opPermuationMap); + return forOp; +} + +void updateLoopResultUses( + const size_t groupIdx, const size_t groupSize, + llvm::SmallVector, 8> &groupResultYeildSet, + const func::FuncOp &func, scf::ForOp *forOp, + IRMapping &mapOpResultToYield) { + // update loop result uses + for (auto [retIdx, rt] : llvm::enumerate(groupResultYeildSet[groupIdx])) { + mapOpResultToYield.map(rt, forOp->getResult(retIdx)); + } + auto currentIdx = groupIdx; + func->walk([&](Operation *op) { + for (auto [opdIdx, opd] : llvm::enumerate(op->getOperands())) { + if (groupResultYeildSet[currentIdx].contains(opd) && + opd.getDefiningOp() != op && + opd.getDefiningOp()->getBlock() != op->getBlock()) { + op->setOperand(opdIdx, mapOpResultToYield.getValueMap().at(opd)); + } + } + }); +} + +void generateGroupOpVectorizedIR( + std::queue &grp, const size_t idx, + llvm::SmallVector, 8> &opGroups, + llvm::DenseMap &opGroupIndexMap, + llvm::SmallVector &groupsShapes, + llvm::SmallVector, 8> &groupResultYeildSet, + llvm::SmallVector, 8> &groupOpDestination, + IRMapping &mapOpResultToYield, func::FuncOp &func, + llvm::DenseMap &opPermuationMap) { + if (grp.empty()) { + LDBG("Current operation Group is empty."); + return; + } + IRRewriter rewriter(grp.back()); + rewriter.setInsertionPointAfter(grp.back()); + // 1. Rewrite operation as vectorized form + rewriteOperationAsVectorize(opGroups[idx], opGroupIndexMap, rewriter, + opPermuationMap); + // 2. Generate loop + auto forOp = generateVectorizedForLoop( + rewriter, groupResultYeildSet[idx], groupOpDestination[idx], + groupsShapes[idx], opGroups[idx], opPermuationMap); + // special operation do not need to change anything + if (failed(forOp)) { + return; + } + // 3 Update loop result uses + updateLoopResultUses(idx, opGroups.size(), groupResultYeildSet, func, + &forOp.value(), mapOpResultToYield); +} + +/// Pass that lower to tile vector. +struct CPUPhysicalRegisterPass + : public impl::CPUPhysicalRegisterPassBase { + + void runOnOperation() final { + // + auto *ctx = &getContext(); + RewritePatternSet patterns(ctx); + auto func = getOperation(); + + // 1. Classify operaions: + // classify the operations into : + // a. reorder, transpose. Reorder(or transpose) dim may bring data + // dependency. + // b. elemenwise. 
Those operations can be fused into a common for loop.
+    //    c. broadcast. Need to analyze the broadcast dims and the data
+    //    dependency.
+    //    d. reduction. Need to analyze the reduction dims and the data
+    //    dependency.
+
+    // Use a queue to preserve the operation order, so that moving the
+    // operations later does not change the semantics.
+    llvm::SmallVector<std::queue<Operation *>, 8> opGroups;
+    llvm::SmallVector<VectorType, 8> groupsShapes;
+    // dummy first group
+    opGroups.emplace_back(std::queue<Operation *>());
+
+    // Query which group the current operation belongs to (returns the group
+    // index).
+    llvm::DenseMap<Operation *, size_t> opGroupIndexMap;
+    classifyOperations(func, opGroups, opGroupIndexMap, groupsShapes);
+
+    // 2. Analyze the operations' operands and results.
+    // We need to know which operation results are needed by other operations
+    // and pass these results on correctly. Each operation result value is
+    // mapped to the corresponding for-loop yield result value, so an operand
+    // can be replaced as: map(operand, for-loop yield result) -> operand =
+    // loop yield result. All such operation results are put into this map.
+
+    // 2.a. Which results must the current group produce so that they can be
+    // used as operands by other operations?
+
+    // Traverse all operations. If an operand of an operation in another group
+    // (or outside any group) is the result of a current-group operation, then
+    // the current group needs to yield that result. We use a `SetVector` to
+    // save the results that must be produced by the current group.
+
+    // 2.b. Which operands does the current group need, and where can they be
+    // obtained?
+
+    // Thanks to 2.a, we know the results produced by the operations of each
+    // group; each such result is materialized by a `for loop yield`. Since the
+    // scope of the parent block covers the current operation, the current
+    // operation does not need to pass these `for loop results` through the
+    // `iter args` of the enclosing `for loop`. It only needs to replace its
+    // operand with the corresponding `for loop yield result`.
+
+    // However, operations that are not DPS need to be canonicalized first.
+    // Canonicalization means that the operand of such an operation is a vector
+    // that we cannot access directly because it is defined in another block
+    // with a different scope. Therefore the vector result is written into a
+    // temporary tensor, and the vector is read back from that tensor before
+    // the current operation operates on it. To do so, `empty tensor`,
+    // `transfer_write` and `transfer_read` ops are inserted at the proper
+    // places.
+
+    llvm::SmallVector<llvm::SetVector<Value>, 8> groupOperandNeedSet(
+        opGroups.size(), llvm::SetVector<Value>()),
+        groupResultYeildSet(opGroups.size(), llvm::SetVector<Value>()),
+        groupOpDestination(opGroups.size(), llvm::SetVector<Value>());
+    // Query groupResultYeildSet to map operation result values to scf.yield
+    // result values.
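+    // For example, if the result of a transfer_write in one group is consumed
+    // by an operation in a later group, that value is recorded in
+    // groupResultYeildSet and, once the group's loop has been emitted, is
+    // remapped to the corresponding scf.for result in updateLoopResultUses.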
+ IRMapping mapOpResultToYield; + analysisGroupOperaionOperandsResults( + opGroups, opGroupIndexMap, groupOperandNeedSet, func, + groupResultYeildSet, mapOpResultToYield, groupOpDestination); + + OpBuilder builder(ctx); + // store read and write operations permutation maps in order to convenient + // to replace loop induction var + llvm::DenseMap opPermuationMap; + + // 3.Generate vectorized IR for each operation group + for (auto [idx, grp] : llvm::enumerate(opGroups)) { + + generateGroupOpVectorizedIR(grp, idx, opGroups, opGroupIndexMap, + groupsShapes, groupResultYeildSet, + groupOpDestination, mapOpResultToYield, func, + opPermuationMap); + } + } +}; +} // namespace + +} // namespace gc +} // namespace mlir \ No newline at end of file diff --git a/lib/gc/Transforms/LowerTileVectorPass.cpp b/lib/gc/Transforms/LowerTileVectorPass.cpp new file mode 100644 index 000000000..900b685f7 --- /dev/null +++ b/lib/gc/Transforms/LowerTileVectorPass.cpp @@ -0,0 +1,489 @@ +//===- LowerTileVectorPass.cpp.cpp - OneDNNGraph To Linalg +// Lowering -*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tensor/Transforms/Transforms.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/Support/Casting.h" + +namespace mlir { +namespace gc { + +#define GEN_PASS_DEF_LOWERTOTILEVECTOR +#include "gc/Transforms/Passes.h.inc" +namespace { +#define DEBUG_TYPE "lower-to-tile-vector-pass" + +#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +bool is_innermost_ir(Operation *op) { + bool inner_most = true; + op->walk([&inner_most](Operation *p) { + if (llvm::isa(p)) { + inner_most = false; + return WalkResult::interrupt(); + } + }); + return inner_most; +} + +/// Need to check if the reassociation are static/constant. +LogicalResult lowerExpandOpPrecondition(tensor::ExpandShapeOp expandOp) { + + // if (llvm::any_of(expandOp.getReassociation(), [](ArrayAttr res) { + // if (llvm::any_of(res, [](Attribute x) { + // return !getConstantIntValue(x).has_value(); + // })) { + // return false; + // } + // return true; + // })) { + // LDBG("Reassociation must be constant: " << expandOp << "\n"); + // return failure(); + // } + + return success(); +} + +LogicalResult lowerTargetOpPrecondition(Operation *op, + ArrayRef inputVectorSizes, + ArrayRef inputScalableVecDims, + bool vectorizeNDExtract, + bool flatten1DDepthwiseConv) { + + return TypeSwitch(op) + .Case([&](auto expandShapeOp) { + return lowerExpandOpPrecondition(expandShapeOp); + }) + .Case( + [&](auto collapseShapeOp) { return success(); }) + .Case([&](auto collapseShapeOp) { return success(); }) + .Case([&](auto concatOp) { return success(); }) + .Default([](auto) { return failure(); }); +} + +/// Create a TransferReadOp from `source` with static shape `readShape`. 
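+/// As currently written the helper only handles the case where `readShape`
+/// equals the static source shape (it asserts otherwise); no mask is created.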
+Value createTransferRead(OpBuilder &builder, Location loc, Value source, + ArrayRef readShape, Value padValue) { + assert(llvm::none_of(readShape, + [](int64_t s) { return s == ShapedType::kDynamic; })); + assert(source && " source null."); + auto sourceShape = dyn_cast(source.getType()).getShape(); + assert(sourceShape.size() == readShape.size()); + auto vectorType = VectorType::get(readShape, padValue.getType()); + int64_t readRank = readShape.size(); + auto zero = builder.create(loc, 0); + SmallVector inBoundsVal(readRank, true); + auto transferReadOp = builder.create( + loc, + /*vectorType=*/vectorType, + /*source=*/source, + /*indices=*/SmallVector(readRank, zero), + /*padding=*/padValue, + /*inBounds=*/inBoundsVal); + + if (llvm::equal(readShape, sourceShape)) { + return transferReadOp; + } else { + assert(false && "wrong shape."); + } +} + +/// Given an input, the mixed destSizes, and the vector sizes for vectorization, +/// create an empty destination tensor and create a TransferWriteOp from the +/// input to the empty tensor. +Operation *createTransferWrite(OpBuilder &builder, Location loc, Value input, + SmallVector destSizes, + ArrayRef inputVectorSizes) { + auto inputType = cast(input.getType()); + Value dest = builder.create(loc, destSizes, + inputType.getElementType()); + int64_t rank = cast(dest.getType()).getRank(); + auto zero = builder.create(loc, 0); + Operation *write = builder.create( + loc, + /*vector=*/input, + /*source=*/dest, + /*indices=*/SmallVector(rank, zero), + /*inBounds=*/SmallVector(rank, true)); + auto destShape = cast(dest.getType()).getShape(); + assert(llvm::none_of( + destShape.drop_front(inputVectorSizes.size()), + [](int64_t size) { return size == ShapedType::kDynamic; }) && + "Only dims aligned with inputVectorSizes may be dynamic"); + return write; +} + +/// Vectorize a `tensor::expandshape` to these 3 Ops: +/// Vector::TransferReadOp - Reads a vector from the source tensor +/// ShapeCastOp - Reshape the data based on the target. +/// vector::TransferWriteOp. - Write the result vector back to the destination +/// tensor +template +LogicalResult lowerTensorExpandShapeOp(RewriterBase &rewriter, T expandShapeOp, + ArrayRef inputVectorSizes, + SmallVectorImpl &newResults) { + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(expandShapeOp); + + RankedTensorType expandShapeTensorType = expandShapeOp.getSrcType(); + + SmallVector readMaskShape(inputVectorSizes.begin(), + inputVectorSizes.end()); + ArrayRef sourceShape = expandShapeTensorType.getShape(); + ArrayRef resultShape = expandShapeOp.getResultType().getShape(); + + readMaskShape.append(sourceShape.begin() + inputVectorSizes.size(), + sourceShape.end()); + + ReifiedRankedShapedTypeDims reifiedRetShapes; + LogicalResult status = + cast(expandShapeOp.getOperation()) + .reifyResultShapes(rewriter, reifiedRetShapes); + if (status.failed()) { + LDBG("Unable to reify result shapes of " << expandShapeOp << "\n"); + return failure(); + } + Location loc = expandShapeOp->getLoc(); + + auto padValue = rewriter.create( + loc, rewriter.getZeroAttr(expandShapeTensorType.getElementType())); + + // Read result, mask if necessary. If transferReadOp shape is not equal + // to shape of source, then a mask is necessary. 
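+  // For example, expanding tensor<8xf32> into tensor<1x1x1x8xf32> (the
+  // add_tensor_test5 case in the tests) lowers roughly to:
+  //   %r = vector.transfer_read %src[...] : tensor<8xf32>, vector<8xf32>
+  //   %c = vector.shape_cast %r : vector<8xf32> to vector<1x1x1x8xf32>
+  //   vector.transfer_write %c, %empty[...] : vector<1x1x1x8xf32>,
+  //                                           tensor<1x1x1x8xf32>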
+ Value readResult = createTransferRead( + rewriter, loc, expandShapeOp.getSrc(), + ArrayRef(readMaskShape.begin(), readMaskShape.end()), padValue); + + auto resultVectorType = + VectorType::get(resultShape, expandShapeTensorType.getElementType()); + vector::ShapeCastOp shapeCastOp = + rewriter.create(loc, resultVectorType, readResult); + + SmallVector writeMaskShape( + expandShapeOp.getResultType().hasStaticShape() + ? inputVectorSizes + : shapeCastOp.getResultVectorType().getShape()); + Operation *write = createTransferWrite(rewriter, loc, shapeCastOp.getResult(), + reifiedRetShapes[0], writeMaskShape); + newResults.push_back(write->getResult(0)); + return success(); +} + +/// Vectorize a `tensor::bitcast` to these 3 Ops: +/// vector::TransferReadOp - Reads a vector from the source tensor +/// vector.Bitcast - Bitcast the data based on the target. +/// vector::TransferWriteOp. - Write the result vector back to the destination +/// tensor +LogicalResult lowerTensorBitcastOp(RewriterBase &rewriter, + tensor::BitcastOp bitCastOp, + ArrayRef inputVectorSizes, + SmallVectorImpl &newResults) { + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(bitCastOp); + + auto sourceType = bitCastOp.getSource().getType(); + auto sourceShape = sourceType.getShape(); + auto resultType = bitCastOp.getResult().getType(); + auto resultShape = resultType.getShape(); + + SmallVector readMaskShape(inputVectorSizes.begin(), + inputVectorSizes.end()); + + readMaskShape.append(sourceShape.begin() + inputVectorSizes.size(), + sourceShape.end()); + + Location loc = bitCastOp->getLoc(); + + auto padValue = rewriter.create( + loc, rewriter.getZeroAttr(sourceType.getElementType())); + + // Read result, mask if necessary. If transferReadOp shape is not equal + // to shape of source, then a mask is necessary. + Value readResult = createTransferRead( + rewriter, loc, bitCastOp->getOperand(0), + ArrayRef(readMaskShape.begin(), readMaskShape.end()), padValue); + + auto resultVectorType = + VectorType::get(resultShape, resultType.getElementType()); + vector::BitCastOp vectorbitCastOp = + rewriter.create(loc, resultVectorType, readResult); + + Value zero = rewriter.create(loc, 0); + SmallVector indices(resultType.getRank(), zero); + Value dest = rewriter.create(loc, resultShape, + resultType.getElementType()); + Operation *write = rewriter.create( + loc, vectorbitCastOp, dest, indices, + rewriter.getMultiDimIdentityMap(resultType.getRank())); + newResults.push_back(write->getResults()[0]); + return success(); +} + +/// Vectorize a `tensor::concat` to these 3 Ops: +/// Tensor::EmptyOp - The result tensor. +/// Vector::TransferWriteOp - Write the result vector back to the destination +/// tensor. +/// Vector::TransferWriteOp - Write the result vector back to the destination +/// tensor. +LogicalResult lowerTensorConcatOp(RewriterBase &rewriter, + tensor::ConcatOp concatOp, + ArrayRef inputVectorSizes, + SmallVectorImpl &newResults) { + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(concatOp); + + Location loc = concatOp.getLoc(); + FailureOr dest = + tensor::getOrCreateDestination(rewriter, loc, concatOp->getResult(0)); + if (failed(dest)) + return failure(); + + auto empty = dest->getDefiningOp(); + if (!empty) + return failure(); + + // Compute the partial sums for the slice offsets. 
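+  // For example, concatenating two tensor<1x1x64xf32> inputs along dim 2 (as
+  // in the static concat test) reads each input as a vector<1x1x64xf32> and
+  // writes it into the shared destination at offsets [0, 0, 0] and [0, 0, 64],
+  // where the second offset is accumulated from tensor.dim of the first input.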
+ + int64_t dim = concatOp.getDim(); + Value dimValue = + rewriter.create(loc, rewriter.getIndexAttr(dim)); + + int64_t rank = concatOp.getResultType().getRank(); + SmallVector strides(rank, rewriter.getIndexAttr(1)); + SmallVector offsets(rank, rewriter.getIndexAttr(0)); + + // Construct the chain of insert_slice ops into the destination. + Value result = *dest; + Value previous_offset = rewriter.create(loc, 0); + for (auto input : concatOp.getInputs()) { + + SmallVector sizes = + tensor::getMixedSizes(rewriter, loc, input); + SmallVector readMaskShape(inputVectorSizes.begin(), + inputVectorSizes.end()); + auto inputType = llvm::cast(input.getType()); + auto sourceShape = inputType.getShape(); + + readMaskShape.append(sourceShape.begin() + inputVectorSizes.size(), + sourceShape.end()); + auto padValue = rewriter.create( + loc, rewriter.getZeroAttr(inputType.getElementType())); + Value readResult = createTransferRead( + rewriter, loc, input, + ArrayRef(readMaskShape.begin(), readMaskShape.end()), + padValue); + Value zero = rewriter.create(loc, 0); + SmallVector indices(rank, zero); + indices[dim] = previous_offset; + result = rewriter + .create( + loc, readResult, result, indices, + rewriter.getMultiDimIdentityMap(rank)) + ->getResults()[0]; + auto dimOp = rewriter.create(loc, input, dimValue); + previous_offset = + rewriter.create(loc, dimOp, previous_offset); + } + + newResults.push_back(result); + return success(); +} + +/// Emit a suitable vector form for an operation. If provided, +/// `inputVectorSizes` are used to vectorize this operation. +/// `inputVectorSizes` must match the rank of the iteration space of the +/// operation and the input vector sizes must be greater than or equal to +/// their counterpart iteration space sizes, if static. `inputVectorShapes` +/// also allows the vectorization of operations with dynamic shapes. 
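+/// Note that the patterns below currently call this entry point with empty
+/// `inputVectorSizes`, so the read shapes fall back to the full static source
+/// shapes and no masking is generated.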
+LogicalResult convert2TargetOperation(RewriterBase &rewriter, Operation *op, + ArrayRef inputVectorSizes, + ArrayRef inputScalableVecDims, + bool vectorizeNDExtract, + bool flatten1DDepthwiseConv) { + LDBG("Attempting to vectorize:\n" << *op << "\n"); + LDBG("Input vector sizes: "); + LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs())); + LLVM_DEBUG(llvm::dbgs() << "\n"); + LDBG("Input scalable vector dims: "); + LLVM_DEBUG(llvm::interleaveComma(inputScalableVecDims, llvm::dbgs())); + LLVM_DEBUG(llvm::dbgs() << "\n"); + + if (failed(lowerTargetOpPrecondition(op, inputVectorSizes, + inputScalableVecDims, vectorizeNDExtract, + flatten1DDepthwiseConv))) { + LDBG("Vectorization pre-conditions failed\n"); + return failure(); + } + + SmallVector results; + auto lowerResult = + TypeSwitch(op) + .Case([&](auto expandShapeOp) { + return lowerTensorExpandShapeOp(rewriter, expandShapeOp, + inputVectorSizes, results); + }) + .Case([&](auto collapseShapeOp) { + return lowerTensorExpandShapeOp(rewriter, collapseShapeOp, + inputVectorSizes, results); + }) + .Case([&](auto bitCastOp) { + return lowerTensorBitcastOp(rewriter, bitCastOp, inputVectorSizes, + results); + }) + .Case([&](auto concatOp) { + return lowerTensorConcatOp(rewriter, concatOp, inputVectorSizes, + results); + }) + .Default([](auto) { return failure(); }); + + if (failed(lowerResult)) { + LDBG("Lower failed\n"); + return failure(); + } + + if (!results.empty()) + rewriter.replaceOp(op, results); + else + rewriter.eraseOp(op); + + return success(); +} + +bool is_required_tensorOp(Operation *operation) { + return llvm::isa(operation) || + llvm::isa(operation) || + llvm::isa(operation) || + llvm::isa(operation); +} + +struct LinalgConvertTileVectorPass : public RewritePattern { + + explicit LinalgConvertTileVectorPass(MLIRContext *context, + bool vectorizeExtract = false, + bool flatten1DDepthwiseConv = false) + : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + + auto linalgOp = llvm::dyn_cast(op); + if (!linalgOp || !is_innermost_ir(op)) + return rewriter.notifyMatchFailure(op, "Not expected operations."); + + return linalg::vectorize(rewriter, op, /*inputVectorSizes=*/{}, + /*scalableVecDims=*/{}, true, false); + } +}; + +struct TensorPackConvertVectorPass : public RewritePattern { + + explicit TensorPackConvertVectorPass(MLIRContext *context, + bool vectorizeExtract = false, + bool flatten1DDepthwiseConv = false) + : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + + tensor::PackOp tensorPackOp = dyn_cast(op); + if (!tensorPackOp || !is_innermost_ir(op)) + return rewriter.notifyMatchFailure(op, "Not expected operations."); + + return linalg::vectorize(rewriter, op, /*inputVectorSizes=*/{}, + /*scalableVecDims=*/{}, true, false); + } +}; + +struct TensorUnpackConvertVectorPass : public RewritePattern { + + explicit TensorUnpackConvertVectorPass(MLIRContext *context, + bool vectorizeExtract = false, + bool flatten1DDepthwiseConv = false) + : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + + tensor::UnPackOp tensorUnPackOp = dyn_cast(op); + + if (!tensorUnPackOp || !is_innermost_ir(op)) + return rewriter.notifyMatchFailure(op, "Not expected operations."); + + Value resultValue = op->getResult(0); + auto 
resultTy = dyn_cast(resultValue.getType()); + if (!resultTy) + return rewriter.notifyMatchFailure(op, "expected ranked tensor type"); + + llvm::ArrayRef inputShape = resultTy.getShape(); + std::vector targetVectorSizes = inputShape.vec(); + llvm::SmallVector targetVecDims(inputShape.size(), false); + return linalg::vectorize(rewriter, op, + /*inputVectorSizes=*/targetVectorSizes, + /*scalableVecDims=*/targetVecDims, true, false); + } +}; + +struct TensorOpConvertVectorPass : public RewritePattern { + + explicit TensorOpConvertVectorPass(MLIRContext *context, + bool vectorizeExtract = false, + bool flatten1DDepthwiseConv = false) + : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + + bool is_target = is_required_tensorOp(op); + if (!is_target || !is_innermost_ir(op)) + return rewriter.notifyMatchFailure(op, "Not expected operations."); + + return convert2TargetOperation(rewriter, op, /*inputVectorSizes=*/{}, + /*scalableVecDims=*/{}, true, false); + } +}; + +/// Pass that lower to tile vector. +void populateLowerToTileVectorPatterns(RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); +} + +struct LowerTileVectorPass + : public impl::LowerToTileVectorBase { + void runOnOperation() final { + // + auto *ctx = &getContext(); + RewritePatternSet patterns(ctx); + tensor::populateRewriteAsConstantPatterns(patterns); + tensor::populateReassociativeReshapeFoldingPatterns(patterns); + populateLowerToTileVectorPatterns(patterns); + linalg::populatePadOpVectorizationPatterns(patterns); + tensor::populateFoldTensorSubsetOpPatterns(patterns); + tensor::populateFoldTensorEmptyPatterns(patterns, true); + tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns); + + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + } +}; +} // namespace + +std::unique_ptr createLowerTileVectorPass() { + return std::make_unique(); +} +} // namespace gc +} // namespace mlir \ No newline at end of file diff --git a/test/gc/transforms/cpu-vetor-distribution.mlir b/test/gc/transforms/cpu-vetor-distribution.mlir new file mode 100644 index 000000000..69f5ed058 --- /dev/null +++ b/test/gc/transforms/cpu-vetor-distribution.mlir @@ -0,0 +1,30 @@ +// RUN: gc-opt --split-input-file --lower-to-tile-vector --CPU-physical-register-pass --mlir-print-ir-after-all -- %s + +// CHECK-LABEL: func @add_tensor +func.func @add_tensor_test0(%arg0: tensor<4x8x16xf32>, %arg1: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> { + %0 = tensor.empty() : tensor<4x8x16xf32> + %1 = linalg.add ins(%arg0, %arg1 : tensor<4x8x16xf32>, tensor<4x8x16xf32>) outs(%0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> + %2 = linalg.add ins(%1, %arg1 : tensor<4x8x16xf32>, tensor<4x8x16xf32>) outs(%0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> + return %2 : tensor<4x8x16xf32> +} + +func.func @fc_relu(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>, + %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>) + -> tensor<512x512xf32> { + // Matrix-matrix multiplication. + %matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>) + outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32> + + // Elementwise addition. 
+ %biased = linalg.elemwise_binary { fun = #linalg.binary_fn } + ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>) + outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> + + // Elementwise max with 0 (ReLU). + %c0f = arith.constant 0.0 : f32 + // expected-remark @below {{elementwise binary}} + %relued = linalg.elemwise_binary { fun = #linalg.binary_fn } + ins(%biased, %c0f : tensor<512x512xf32>, f32) + outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> + func.return %relued : tensor<512x512xf32> +} diff --git a/test/gc/transforms/linalg-vectorization.mlir b/test/gc/transforms/linalg-vectorization.mlir new file mode 100644 index 000000000..a3b68a92d --- /dev/null +++ b/test/gc/transforms/linalg-vectorization.mlir @@ -0,0 +1,99 @@ +// RUN: gc-opt --split-input-file -pass-pipeline='builtin.module(func.func(lower-to-tile-vector))' --mlir-print-ir-after-all -- %s + +// CHECK-LABEL: func @add_tensor +func.func @add_tensor_test0(%arg0: tensor<4x8x16xf32>, %arg1: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> { + %0 = tensor.empty() : tensor<4x8x16xf32> + %1 = linalg.add ins(%arg0, %arg1 : tensor<4x8x16xf32>, tensor<4x8x16xf32>) outs(%0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> + return %1 : tensor<4x8x16xf32> +} + +func.func @add_tensor_test1(%arg0: tensor<4x8x16xf32>, %arg1: tensor<4x8x16xf32>) -> tensor<1x8x8xf32> { + %0 = tensor.empty() : tensor<1x8x8xf32> + %1 = tensor.extract_slice %arg0[0, 0, 0] [1, 8, 8] [1, 1, 1] : tensor<4x8x16xf32> to tensor<1x8x8xf32> + %2 = tensor.extract_slice %arg1[0, 0, 0] [1, 8, 8] [1, 1, 1] : tensor<4x8x16xf32> to tensor<1x8x8xf32> + %3 = linalg.add ins(%1, %2 : tensor<1x8x8xf32>, tensor<1x8x8xf32>) outs(%0: tensor<1x8x8xf32>) -> tensor<1x8x8xf32> + return %3 : tensor<1x8x8xf32> +} + +func.func @add_tensor_pack_test2(%arg0: tensor<4x16x16xf32>, %arg1: tensor<4x16x16xf32>) -> tensor<4x4x4x4x4xf32> { + %0 = tensor.empty() : tensor<4x4x4x4x4xf32> + %1 = tensor.empty() : tensor<4x4x4x4x4xf32> + %2 = tensor.pack %arg0 outer_dims_perm = [1, 0, 2] inner_dims_pos = [1, 2] inner_tiles = [4, 4] into %0 : tensor<4x16x16xf32> -> tensor<4x4x4x4x4xf32> + %3 = tensor.pack %arg1 outer_dims_perm = [1, 0, 2] inner_dims_pos = [1, 2] inner_tiles = [4, 4] into %1 : tensor<4x16x16xf32> -> tensor<4x4x4x4x4xf32> + %4 = tensor.empty() : tensor<4x4x4x4x4xf32> + %6 = linalg.add ins(%2, %3 : tensor<4x4x4x4x4xf32>, tensor<4x4x4x4x4xf32>) outs(%4: tensor<4x4x4x4x4xf32>) -> tensor<4x4x4x4x4xf32> + return %6 : tensor<4x4x4x4x4xf32> +} + +func.func @add_tensor_pad_test3(%arg0: tensor<4x16x15xf32>, %arg1: tensor<4x16x15xf32>) -> tensor<4x16x16xf32> { + %cst = arith.constant 0.000000e+00 : f32 + %0 = tensor.pad %arg0 low[0, 0, 0] high[0, 0, 1] { + ^bb0(%arg2: index, %arg3: index, %arg4: index): + tensor.yield %cst : f32 + } : tensor<4x16x15xf32> to tensor<4x16x16xf32> + %1 = tensor.pad %arg1 low[0, 0, 0] high[0, 0, 1] { + ^bb0(%arg5: index, %arg6: index, %arg7: index): + tensor.yield %cst : f32 + } : tensor<4x16x15xf32> to tensor<4x16x16xf32> + %2 = tensor.empty() : tensor<4x16x16xf32> + %3 = linalg.add ins(%0, %1 : tensor<4x16x16xf32>, tensor<4x16x16xf32>) outs(%2: tensor<4x16x16xf32>) -> tensor<4x16x16xf32> + return %3 : tensor<4x16x16xf32> +} + +func.func @add_tensor_test4(%arg0: tensor<12x2x56x56x32xf32>, %arg1: tensor<12x2x56x56x32xf32>) -> tensor<12x56x56x64xf32> { + %0 = tensor.empty() : tensor<12x56x56x64xf32> + %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %0 : tensor<12x2x56x56x32xf32> -> 
tensor<12x56x56x64xf32> + %2 = tensor.empty() : tensor<12x56x56x64xf32> + %3 = tensor.unpack %arg1 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %2 : tensor<12x2x56x56x32xf32> -> tensor<12x56x56x64xf32> + %4 = tensor.empty() : tensor<12x56x56x64xf32> + %5 = linalg.add ins(%1, %3 : tensor<12x56x56x64xf32>, tensor<12x56x56x64xf32>) outs(%4: tensor<12x56x56x64xf32>) -> tensor<12x56x56x64xf32> + return %5 : tensor<12x56x56x64xf32> +} + +func.func @add_tensor_test5() -> tensor<1x1x1x8xf32> { + %cst = arith.constant 1.000000e+00 : f32 + %init = tensor.empty() : tensor<1x8xf32> + %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<1x8xf32>) -> tensor<1x8xf32> + %slice = tensor.extract_slice %fill[0, 0] [1, 8] [1, 1] : tensor<1x8xf32> to tensor<8xf32> + %expand = tensor.expand_shape %slice [[0, 1, 2, 3]] : tensor<8xf32> into tensor<1x1x1x8xf32> + return %expand : tensor<1x1x1x8xf32> +} + +func.func @tensor_collapse_shape_test0(%arg0: tensor<2x3xf32>) -> tensor<6xf32> { + %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<2x3xf32> into tensor<6xf32> + return %0 : tensor<6xf32> +} + +func.func @tensor_bitcast_test0(%input: tensor<2xi32>) -> tensor<2xf32> { + %0 = tensor.bitcast %input : tensor<2xi32> to tensor<2xui32> + %1 = tensor.bitcast %0 : tensor<2xui32> to tensor<2xf32> + return %1 : tensor<2xf32> +} + +func.func @tensor_static_concat_test0(%arg0 : tensor<1x1x64xf32>, + %arg1: tensor<1x1x64xf32>) -> tensor<1x1x128xf32> { + %0 = tensor.concat dim(2) %arg0, %arg1 + : (tensor<1x1x64xf32>, tensor<1x1x64xf32>) -> tensor<1x1x128xf32> + return %0 : tensor<1x1x128xf32> +} + +func.func @fc_relu(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>, + %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>) + -> tensor<512x512xf32> { + // Matrix-matrix multiplication. + %matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>) + outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32> + + // Elementwise addition. + %biased = linalg.elemwise_binary { fun = #linalg.binary_fn } + ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>) + outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> + + // Elementwise max with 0 (ReLU). 
+ %c0f = arith.constant 0.0 : f32 + // expected-remark @below {{elementwise binary}} + %relued = linalg.elemwise_binary { fun = #linalg.binary_fn } + ins(%biased, %c0f : tensor<512x512xf32>, f32) + outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> + func.return %relued : tensor<512x512xf32> +} From 8d65cbff7beb6ddf72e814b002b6e01499081216 Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Tue, 28 May 2024 11:36:07 +0800 Subject: [PATCH 02/66] add LoopInvariantCodeMotion and CSE --- lib/gc/Transforms/CPUPhysicalResigterPass.cpp | 13 ++++++++++++- lib/gc/Transforms/LowerTileVectorPass.cpp | 1 + 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/lib/gc/Transforms/CPUPhysicalResigterPass.cpp b/lib/gc/Transforms/CPUPhysicalResigterPass.cpp index 120d12c3e..9ebcf2fb2 100644 --- a/lib/gc/Transforms/CPUPhysicalResigterPass.cpp +++ b/lib/gc/Transforms/CPUPhysicalResigterPass.cpp @@ -20,7 +20,9 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/CSE.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" #include "llvm/Support/Casting.h" #include #include @@ -57,6 +59,7 @@ bool is_innermost_operation(Operation *op) { inner_most = false; return WalkResult::interrupt(); } + return WalkResult::advance(); }); return inner_most; } @@ -861,9 +864,10 @@ void generateGroupOpVectorizedIR( // 3 Update loop result uses updateLoopResultUses(idx, opGroups.size(), groupResultYeildSet, func, &forOp.value(), mapOpResultToYield); + moveLoopInvariantCode(forOp.value()); } -/// Pass that lower to tile vector. +/// Pass that lower to physical vector. struct CPUPhysicalRegisterPass : public impl::CPUPhysicalRegisterPassBase { @@ -882,6 +886,8 @@ struct CPUPhysicalRegisterPass // dependency. // d. reduction. Need to analysis broadcast dim and the // data dependency. + // Same group operations have no data dependencies. They can be fused into a + // common for loop body. // Using queue to store the operation order. In order to ensure that // subsequent moves to the operation will not cause semantic changes. @@ -953,6 +959,11 @@ struct CPUPhysicalRegisterPass groupOpDestination, mapOpResultToYield, func, opPermuationMap); } + + // 4. 
Some IR cleanup work + DominanceInfo domInfo; + auto reWriter = IRRewriter(func); + eliminateCommonSubExpressions(reWriter, domInfo, func); } }; } // namespace diff --git a/lib/gc/Transforms/LowerTileVectorPass.cpp b/lib/gc/Transforms/LowerTileVectorPass.cpp index 900b685f7..1a8fb6476 100644 --- a/lib/gc/Transforms/LowerTileVectorPass.cpp +++ b/lib/gc/Transforms/LowerTileVectorPass.cpp @@ -39,6 +39,7 @@ bool is_innermost_ir(Operation *op) { inner_most = false; return WalkResult::interrupt(); } + return WalkResult::advance(); }); return inner_most; } From e6037e25c282eeddbba4275430cbd80db17f835c Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Thu, 30 May 2024 14:52:35 +0800 Subject: [PATCH 03/66] update for result use rewriter --- include/gc/Transforms/Passes.h | 9 +++ lib/gc/Transforms/CPUPhysicalResigterPass.cpp | 52 +++++++++------- lib/gc/Transforms/LowerTileVectorPass.cpp | 6 ++ .../gc/transforms/cpu-vetor-distribution.mlir | 60 ++++++++++++------- 4 files changed, 84 insertions(+), 43 deletions(-) diff --git a/include/gc/Transforms/Passes.h b/include/gc/Transforms/Passes.h index 243a6f4f6..18299a81c 100644 --- a/include/gc/Transforms/Passes.h +++ b/include/gc/Transforms/Passes.h @@ -9,9 +9,18 @@ #ifndef GC_PASSES_H #define GC_PASSES_H +#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" #include "mlir/Pass/Pass.h" namespace mlir { +namespace vector { +#define GEN_PASS_DECL +#include "gc/Transforms/Passes.h.inc" +/// Creates an instance of the `vector.multi_reduction` lowering pass. +std::unique_ptr createLowerVectorMultiReductionPass( + VectorMultiReductionLowering option = + VectorMultiReductionLowering::InnerParallel); +} // namespace vector namespace gc { #define GEN_PASS_DECL diff --git a/lib/gc/Transforms/CPUPhysicalResigterPass.cpp b/lib/gc/Transforms/CPUPhysicalResigterPass.cpp index 9ebcf2fb2..a4d2e4e6c 100644 --- a/lib/gc/Transforms/CPUPhysicalResigterPass.cpp +++ b/lib/gc/Transforms/CPUPhysicalResigterPass.cpp @@ -6,6 +6,7 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// +#include "gc/Transforms/Passes.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" @@ -15,6 +16,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/Transforms/Transforms.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/Passes.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/PatternMatch.h" @@ -25,9 +27,11 @@ #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" #include "llvm/Support/Casting.h" #include +#include #include #include #include +#include namespace mlir { namespace gc { @@ -149,7 +153,6 @@ void setOpVectorizationPermutationMap(Operation *op, IRRewriter &rewriter, void maybeYieldValue(OpBuilder &b, Location loc, ValueRange value) { bool hasRetVal = !value.empty(); if (hasRetVal) { - assert(!value.empty() && "Expected non-empty value"); b.create(loc, value); } else { b.create(loc); @@ -171,6 +174,9 @@ void checkAndSetOperand( if (llvm::dyn_cast(op) || llvm::dyn_cast(op)) { assert(opPermuationMap.contains(op)); + op->dump(); + std::cout << inductionVars.size() << std::endl; + auto permutationMap = opPermuationMap.at(op); auto dimExpr = permutationMap.getResults(); @@ -535,8 +541,7 @@ Operation *createTensorEmptyBefore(Operation *op) { dynDims.push_back(reWriter.create(reWriter.getUnknownLoc(), 
op->getResult(0), i)); } - return reWriter.create(reWriter.getUnknownLoc(), - rtType.getShape(), + return reWriter.create(op->getLoc(), rtType.getShape(), rtType.getElementType(), dynDims); } @@ -706,7 +711,7 @@ void analysisGroupOperationResults( llvm::SmallVector, 8> &groupResultYeildSet, llvm::SmallVector, 8> &groupOpDestination) { llvm::DenseMap> srcOpCanoniclizedMap; - + IRRewriter rewriter(func); func.walk([&](Operation *op) { for (auto [idx, opd] : llvm::enumerate(op->getOperands())) { auto sourceOp = opd.getDefiningOp(); @@ -757,6 +762,13 @@ void analysisGroupOperationResults( } } }); + // If the group operations do not have result need to be returned, these are + // useless code. + for (auto [idx, grp] : enumerate(opGroups)) { + if (groupResultYeildSet[idx].empty()) { + std::queue().swap(grp); + } + } LDBG("Complete analysis group operation results\n"); } @@ -815,24 +827,25 @@ mlir::FailureOr generateVectorizedForLoop( } void updateLoopResultUses( - const size_t groupIdx, const size_t groupSize, + const size_t groupIdx, llvm::SmallVector, 8> &groupResultYeildSet, - const func::FuncOp &func, scf::ForOp *forOp, - IRMapping &mapOpResultToYield) { + scf::ForOp *forOp) { + if (groupResultYeildSet[groupIdx].empty()) { + return; + } + IRRewriter rewriter(*forOp); + OpBuilder::InsertionGuard g(rewriter); + // Only different group operation operand need to be replaced due to same + // group operation should directly use original operand. + auto producerOp = groupResultYeildSet[groupIdx].front().getDefiningOp(); + auto needToReplaced = [&](OpOperand &operand) { + return producerOp->getBlock() != operand.getOwner()->getBlock(); + }; // update loop result uses for (auto [retIdx, rt] : llvm::enumerate(groupResultYeildSet[groupIdx])) { - mapOpResultToYield.map(rt, forOp->getResult(retIdx)); + producerOp = rt.getDefiningOp(); + rewriter.replaceUsesWithIf(rt, forOp->getResult(retIdx), needToReplaced); } - auto currentIdx = groupIdx; - func->walk([&](Operation *op) { - for (auto [opdIdx, opd] : llvm::enumerate(op->getOperands())) { - if (groupResultYeildSet[currentIdx].contains(opd) && - opd.getDefiningOp() != op && - opd.getDefiningOp()->getBlock() != op->getBlock()) { - op->setOperand(opdIdx, mapOpResultToYield.getValueMap().at(opd)); - } - } - }); } void generateGroupOpVectorizedIR( @@ -862,8 +875,7 @@ void generateGroupOpVectorizedIR( return; } // 3 Update loop result uses - updateLoopResultUses(idx, opGroups.size(), groupResultYeildSet, func, - &forOp.value(), mapOpResultToYield); + updateLoopResultUses(idx, groupResultYeildSet, &forOp.value()); moveLoopInvariantCode(forOp.value()); } diff --git a/lib/gc/Transforms/LowerTileVectorPass.cpp b/lib/gc/Transforms/LowerTileVectorPass.cpp index 1a8fb6476..944e38775 100644 --- a/lib/gc/Transforms/LowerTileVectorPass.cpp +++ b/lib/gc/Transforms/LowerTileVectorPass.cpp @@ -6,6 +6,7 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// +#include "gc/Transforms/Passes.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" @@ -14,6 +15,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/Transforms/Transforms.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" @@ -477,6 +479,10 @@ struct 
LowerTileVectorPass tensor::populateFoldTensorSubsetOpPatterns(patterns); tensor::populateFoldTensorEmptyPatterns(patterns, true); tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns); + vector::VectorTransformsOptions vectorTransformOptions; + vector::populateVectorMultiReductionLoweringPatterns( + patterns, vectorTransformOptions.vectorMultiReductionLowering); + // vector::populateVectorShapeCastLoweringPatterns(patterns); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } diff --git a/test/gc/transforms/cpu-vetor-distribution.mlir b/test/gc/transforms/cpu-vetor-distribution.mlir index 69f5ed058..5b3235c71 100644 --- a/test/gc/transforms/cpu-vetor-distribution.mlir +++ b/test/gc/transforms/cpu-vetor-distribution.mlir @@ -1,30 +1,44 @@ // RUN: gc-opt --split-input-file --lower-to-tile-vector --CPU-physical-register-pass --mlir-print-ir-after-all -- %s // CHECK-LABEL: func @add_tensor -func.func @add_tensor_test0(%arg0: tensor<4x8x16xf32>, %arg1: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> { - %0 = tensor.empty() : tensor<4x8x16xf32> - %1 = linalg.add ins(%arg0, %arg1 : tensor<4x8x16xf32>, tensor<4x8x16xf32>) outs(%0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> - %2 = linalg.add ins(%1, %arg1 : tensor<4x8x16xf32>, tensor<4x8x16xf32>) outs(%0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> - return %2 : tensor<4x8x16xf32> -} +// func.func @add_tensor_test0(%arg0: tensor<4x8x16xf32>, %arg1: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> { +// %0 = tensor.empty() : tensor<4x8x16xf32> +// %1 = linalg.add ins(%arg0, %arg1 : tensor<4x8x16xf32>, tensor<4x8x16xf32>) outs(%0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> +// %2 = linalg.add ins(%1, %arg1 : tensor<4x8x16xf32>, tensor<4x8x16xf32>) outs(%0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> +// return %2 : tensor<4x8x16xf32> +// } -func.func @fc_relu(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>, - %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>) - -> tensor<512x512xf32> { - // Matrix-matrix multiplication. - %matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>) - outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32> +// func.func @fc_relu(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>, +// %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>) +// -> tensor<512x512xf32> { +// // Matrix-matrix multiplication. +// %matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>) +// outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32> - // Elementwise addition. - %biased = linalg.elemwise_binary { fun = #linalg.binary_fn } - ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>) - outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> +// // Elementwise addition. +// %biased = linalg.elemwise_binary { fun = #linalg.binary_fn } +// ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>) +// outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> - // Elementwise max with 0 (ReLU). - %c0f = arith.constant 0.0 : f32 - // expected-remark @below {{elementwise binary}} - %relued = linalg.elemwise_binary { fun = #linalg.binary_fn } - ins(%biased, %c0f : tensor<512x512xf32>, f32) - outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> - func.return %relued : tensor<512x512xf32> +// // Elementwise max with 0 (ReLU). 
+// %c0f = arith.constant 0.0 : f32 +// // expected-remark @below {{elementwise binary}} +// %relued = linalg.elemwise_binary { fun = #linalg.binary_fn } +// ins(%biased, %c0f : tensor<512x512xf32>, f32) +// outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> +// func.return %relued : tensor<512x512xf32> +// } +func.func @reduce_keepdim0(%arg0: tensor<16x32x64xf32>) -> tensor<16x1x64xf32> { + %0 = tensor.empty() : tensor<16x64xf32> + %reduce = linalg.reduce + ins(%arg0:tensor<16x32x64xf32>) + outs(%0:tensor<16x64xf32>) + dimensions = [1] + (%in: f32, %out: f32) { + %1 = arith.addf %out, %in: f32 + linalg.yield %1: f32 + } + %2 = tensor.expand_shape %reduce [[0],[1, 2]] : tensor<16x64xf32> into tensor<16x1x64xf32> + return %2 : tensor<16x1x64xf32> } + From 24261e77cb3b5ff48b8fc814f6d1b2db74577bcd Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Thu, 30 May 2024 16:43:04 +0800 Subject: [PATCH 04/66] move functions in class --- lib/gc/Transforms/CPUPhysicalResigterPass.cpp | 601 ++++++++++-------- lib/gc/Transforms/LowerTileVectorPass.cpp | 4 +- .../gc/transforms/cpu-vetor-distribution.mlir | 10 +- 3 files changed, 339 insertions(+), 276 deletions(-) diff --git a/lib/gc/Transforms/CPUPhysicalResigterPass.cpp b/lib/gc/Transforms/CPUPhysicalResigterPass.cpp index a4d2e4e6c..bc4163f63 100644 --- a/lib/gc/Transforms/CPUPhysicalResigterPass.cpp +++ b/lib/gc/Transforms/CPUPhysicalResigterPass.cpp @@ -53,7 +53,8 @@ bool isSpecialOp(Operation *op) { llvm::isa(op) || llvm::isa(op) || llvm::isa(op) || - llvm::isa(op); + llvm::isa(op) || + llvm::isa(op); } bool is_innermost_operation(Operation *op) { @@ -113,6 +114,152 @@ mlir::FailureOr getOperationVectorType(Operation *op) { }); } +// 1. Classify operaions: +// classify the operations into : +// a. reorder, transpose. Reorder(or transpose) dim may bring data +// dependency. +// b. elemenwise. Those operations can be fused into a common for loop. +// c. broadcast. Need to analysis broadcast dim and the data +// dependency. +// d. reduction. Need to analysis broadcast dim and the +// data dependency. +// Same group operations have no data dependencies. They can be fused into a +// common for loop body. + +// Using queue to store the operation order. In order to ensure that +// subsequent moves to the operation will not cause semantic changes. 
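+// As a rough illustration only (made-up SSA names and shapes, not verbatim
+// pass output): two elementwise operations over the same shape land in one
+// group and share a single generated loop, e.g.
+//
+//   scf.for %i = %c0 to %c1024 step %c16 iter_args(%dst = %init)
+//       -> (tensor<1024xf32>) {
+//     %a = vector.transfer_read %lhs[%i], %pad : tensor<1024xf32>, vector<16xf32>
+//     %b = vector.transfer_read %rhs[%i], %pad : tensor<1024xf32>, vector<16xf32>
+//     %s = arith.addf %a, %b : vector<16xf32>
+//     %t = arith.mulf %s, %s : vector<16xf32>
+//     %w = vector.transfer_write %t, %dst[%i] : vector<16xf32>, tensor<1024xf32>
+//     scf.yield %w : tensor<1024xf32>
+//   }
+//
+// while special operations such as vector.transpose or vector.multi_reduction
+// are not merged into such an elementwise group.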
+class VectorFusionStrategy { +public: + llvm::SmallVector, 8> &getOpGroups() { + return opGroups; + } + llvm::DenseMap &getOpGroupIndexMap() { + return opGroupIndexMap; + } + + func::FuncOp getFunc() { return func; } + + VectorFusionStrategy() = default; + VectorFusionStrategy(func::FuncOp func) : func(func) {} + + void + classifyOperations(func::FuncOp func, + llvm::SmallVector, 8> &opGroups, + llvm::DenseMap &opGroupIndexMap); + + // run the vector fusion strategy + void run(); + +private: + llvm::SmallVector, 8> opGroups; + // query current operation in which group, return group index + llvm::DenseMap opGroupIndexMap; + + func::FuncOp func; +}; + +void VectorFusionStrategy::run() { + classifyOperations(func, opGroups, opGroupIndexMap); +} + +enum CanonicalizerKind { OperationsGroup, Operations }; + +struct CanonicalizerVectorOperation { + func::FuncOp func; + IRRewriter rewriter; + VectorFusionStrategy fusionStrategy; + CanonicalizerKind kind; + + // analysis the operation's operands and results + llvm::SmallVector, 8> groupOpResults, groupOpIterArgs; + + // store read and write operations permutation maps in order to convenient + // to replace loop induction var + llvm::DenseMap opPermuationMap; + + CanonicalizerVectorOperation( + func::FuncOp func, + CanonicalizerKind kind = CanonicalizerKind::OperationsGroup) + : func(func), rewriter(func), kind(kind) { + // vector operation fusion + if (kind == CanonicalizerKind::OperationsGroup) { + fusionStrategy = VectorFusionStrategy(func); + fusionStrategy.run(); + } + } + + void generateGroupOpVectorizedIR( + const int idx, std::queue &grp, + llvm::DenseMap &opGroupIndexMap); + + void analysisGroupOperaionOperandsResults( + llvm::SmallVector, 8> &opGroups, + llvm::DenseMap &opGroupIndexMap); + + void analysisGroupOperationResults( + llvm::SmallVector, 8> &opGroups, + llvm::DenseMap &opGroupIndexMap); + + void run(); +}; + +void CanonicalizerVectorOperation::run() { + // 1. Analysis the operation's operands and results + // We need to analyze which operation results are needed by other + // operations, and we need to pass these results correctly. Mapping the + // operation result value to forloop yeild result value. We can replace the + // operation operand as: map(operand, forloop yield result) -> operand = + // loop yield result We put all the operation result into this map. + + // 2.a. Find what results should be generated by current group for + // using as operands to other operations? + + // Traverse all operations. If the operand of operations in other groups or + // outside the group is the result of the current group operation, then the + // current operation needs to generate a result. We use `setvector` to save + // the results that need to be generated by the current group. + + // 2.b. What operands are needed to find in the current group, and where + // can they be obtained ? + + // Thanks to 2.a, we get the result generated by the operations of + // each group, and this result will use `for loop yield` to generate a + // new result. Since the scope of the parent block of mlir is covered + // the current operation, the current operation does not need to pass these + // `for loop results` to the `iter args` of the required `for loop`. It + // only needs to replace the operand of the current operation with the + // corresponding `for loop yield result`. + + // However, for some operations that are not DPS, we need to canonicalize + // them. 
Canonicalization means that the operand of this operation is a + // vector but we can't get this vector due to it locates in another block + // which has a different scope. Therefore, it is necessary to write the + // vector results into a temporary tensor to save it. Then the vector needs + // to be read from the tensor before the current operation operate on it. + // Therefore, `empty tensor`, `transfer_write` and `transfer_read` need to + // be inserted at target place. + + // Query groupResultYeildSet to map operaion result value to scf.yield + // result value. + if (kind == CanonicalizerKind::OperationsGroup) { + analysisGroupOperaionOperandsResults(fusionStrategy.getOpGroups(), + fusionStrategy.getOpGroupIndexMap()); + // 3.Generate vectorized IR for each operation group + for (auto [idx, grp] : llvm::enumerate(fusionStrategy.getOpGroups())) { + generateGroupOpVectorizedIR(idx, grp, + fusionStrategy.getOpGroupIndexMap()); + } + + // 4. Some IR cleanup work + DominanceInfo domInfo; + auto reWriter = IRRewriter(func); + eliminateCommonSubExpressions(reWriter, domInfo, func); + } else { + // TODO: need to add directly canonicalize operations + // generateGroupOpVectorizedIR(idx, grp, fusionStrategy.opGroupIndexMap); + } +} + // Filter out the operations that can be vectorized. We are only interested in // operations that do not contain any for loops(innermost IR). [[nodiscard]] bool filterOperation(Operation *op) { @@ -129,12 +276,18 @@ mlir::FailureOr getOperationVectorType(Operation *op) { return true; } -// Since we rewrote transfer_read and transfer_write, the `permutationmap` must +// Since we rewrite transfer_read and transfer_write, the `permutationmap` must // be changed. void setOpVectorizationPermutationMap(Operation *op, IRRewriter &rewriter, - RankedTensorType tensorType) { + const RankedTensorType &tensorType, + const AffineMap &permutationMap) { + + auto dimExpr = permutationMap.getResults(); + auto lastDim = mlir::dyn_cast(dimExpr.back()); + assert(mlir::isa(lastDim)); + SmallVector affineExprs; - affineExprs.push_back(rewriter.getAffineDimExpr(tensorType.getRank() - 1)); + affineExprs.push_back(lastDim); auto destAffineMap = AffineMap::get(tensorType.getRank(), 0, affineExprs, rewriter.getContext()); SmallVector inBounds(1, true); @@ -174,8 +327,6 @@ void checkAndSetOperand( if (llvm::dyn_cast(op) || llvm::dyn_cast(op)) { assert(opPermuationMap.contains(op)); - op->dump(); - std::cout << inductionVars.size() << std::endl; auto permutationMap = opPermuationMap.at(op); @@ -350,15 +501,11 @@ bool isNeedNewGroup(llvm::SmallVector, 8> &opGroups, void addOperationToGroup( llvm::SmallVector, 8> &opGroups, - llvm::DenseMap &opGroupIndexMap, Operation *op, - llvm::SmallVector &groupsShapes) { + llvm::DenseMap &opGroupIndexMap, Operation *op) { // if (isNeedNewGroup(opGroups, op)) { opGroups.emplace_back(std::queue()); } - if (opGroups.size() != groupsShapes.size()) { - groupsShapes.emplace_back(getOperationVectorType(op).value()); - } opGroups.back().push(op); opGroupIndexMap[op] = opGroups.size() - 1; } @@ -366,14 +513,17 @@ void addOperationToGroup( // We classify the operations we are interested in after filtering. Operations // of in the same group have no data dependencies. Those operations can generate // a same outter for loop. 
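+// For the non-DPS results discussed in CanonicalizerVectorOperation::run
+// above, the canonicalization spills the vector through a temporary tensor.
+// A minimal sketch (made-up names, not verbatim pass output):
+//
+//   %buf = tensor.empty() : tensor<1024xf32>
+//   %w   = vector.transfer_write %vec, %buf[%c0]
+//            : vector<16xf32>, tensor<1024xf32>
+//
+// and the consuming group re-materializes the value with
+//
+//   %v = vector.transfer_read %w[%c0], %pad : tensor<1024xf32>, vector<16xf32>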
-void classifyOperations(func::FuncOp func, - llvm::SmallVector, 8> &opGroups, - llvm::DenseMap &opGroupIndexMap, - llvm::SmallVector &groupsShapes) { +void VectorFusionStrategy::classifyOperations( + func::FuncOp func, llvm::SmallVector, 8> &opGroups, + llvm::DenseMap &opGroupIndexMap) { + if (opGroups.empty()) { + // dummpy + opGroups.emplace_back(std::queue()); + } func->walk([&](Operation *op) { TypeSwitch(op).Default([&](Operation *op) { if (filterOperation(op)) { - addOperationToGroup(opGroups, opGroupIndexMap, op, groupsShapes); + addOperationToGroup(opGroups, opGroupIndexMap, op); } }); }); @@ -422,113 +572,6 @@ void setOperationOperandResult( } }; -/// Rewrite the operations in the group to vectorized form. -void rewriteOperationAsVectorize( - const std::queue &groupOps, - llvm::DenseMap &opMap, IRRewriter &rewriter, - llvm::DenseMap &opPermuationMap) { - std::queue transformQueue(groupOps); - - auto getVectorzedType = [](Operation *op) -> VectorType { - // Check that the operation type can be broken - // down into a loop. - auto baseType = getOperationVectorType(op); - if (failed(baseType)) { - LDBG("Failed to get vector type for operation: " << *op << "\n"); - assert(false && "Failed to get vector type for operation"); - return VectorType(); - } - auto vectorizedType = baseType.value(); - const int loop_step = getDataTypeMAXSIMDLength(vectorizedType); - return VectorType::get({loop_step}, vectorizedType.getElementType()); - }; - - while (!transformQueue.empty()) { - auto op = transformQueue.front(); - transformQueue.pop(); - auto lowerResult = - TypeSwitch(op) - .Case( - [&](vector::TransferWriteOp transferWriteOp) { - auto newOperandType = getVectorzedType(transferWriteOp); - if (!isSpecialOp( - transferWriteOp->getOperand(0).getDefiningOp())) { - opPermuationMap.insert( - {transferWriteOp, transferWriteOp.getPermutationMap()}); - transferWriteOp->getOperand(0).setType(newOperandType); - setOpVectorizationPermutationMap( - transferWriteOp, rewriter, - transferWriteOp->getResult(0) - .getType() - .dyn_cast()); - } - - return success(); - }) - .Case( - [&](vector::TransferReadOp transferReadOp) { - auto newOperandType = getVectorzedType(transferReadOp); - auto users = transferReadOp->getUsers(); - bool isUserSpecial = false; - for (auto *opUse : users) { - if (isSpecialOp(opUse)) { - isUserSpecial = true; - break; - } - } - if (!isUserSpecial) { - opPermuationMap.insert( - {transferReadOp, transferReadOp.getPermutationMap()}); - transferReadOp->getResult(0).setType(newOperandType); - setOpVectorizationPermutationMap( - transferReadOp, rewriter, - transferReadOp.getSource() - .getType() - .dyn_cast()); - } - - return success(); - }) - .Case( - [&](vector::MultiDimReductionOp multiReductionOp) { - return success(); - }) - .Default([&](Operation *op) { - if (isSpecialOp(op)) { - return success(); - } - setOperationOperandResult(op, getVectorzedType(op), opMap); - return success(); - }); - if (failed(lowerResult)) { - LDBG("Failed to rewrite operation: " << *op << "\n"); - assert(false && "Failed to rewrite operation"); - } - } -} - -// analysis operation' operands are coming from which operation's result -void analysisOperaionOperandSource( - size_t idx, std::queue &grp, - llvm::DenseMap &opGroupIndexMap, - llvm::SmallVector, 8> &groupOperandNeedSet) { - auto tmpOpQueue(grp); - llvm::SetVector opOperands; - while (!tmpOpQueue.empty()) { - auto t = tmpOpQueue.front(); - for (auto x : t->getOperands()) { - // not in the same group - if (opGroupIndexMap.contains(x.getDefiningOp()) && - 
opGroupIndexMap[x.getDefiningOp()] != idx) { - groupOperandNeedSet[idx].insert(x); - } else { - groupOperandNeedSet[idx].insert(x); - } - } - tmpOpQueue.pop(); - } -} - Operation *createTensorEmptyBefore(Operation *op) { auto rtType = op->getResultTypes()[0].dyn_cast(); IRRewriter reWriter(op); @@ -615,6 +658,118 @@ Operation *createTransferWriteOpAfter(Operation *op, const Value &dest) { /*inBounds=*/inBoundsVal); } +/// Rewrite the operations in the group to vectorized form. +void rewriteOperationAsVectorize( + const std::queue &groupOps, + const llvm::DenseMap &opMap, IRRewriter &rewriter, + llvm::DenseMap &opPermuationMap) { + std::queue transformQueue(groupOps); + + auto getVectorzedType = [](Operation *op) -> VectorType { + // Check that the operation type can be broken + // down into a loop. + auto baseType = getOperationVectorType(op); + if (failed(baseType)) { + LDBG("Failed to get vector type for operation: " << *op << "\n"); + assert(false && "Failed to get vector type for operation"); + return VectorType(); + } + auto vectorizedType = baseType.value(); + const int loop_step = getDataTypeMAXSIMDLength(vectorizedType); + return VectorType::get({loop_step}, vectorizedType.getElementType()); + }; + + while (!transformQueue.empty()) { + auto op = transformQueue.front(); + transformQueue.pop(); + auto lowerResult = + TypeSwitch(op) + .Case( + [&](vector::TransferWriteOp transferWriteOp) { + IRRewriter rewriter(transferWriteOp); + auto newOperandType = getVectorzedType(transferWriteOp); + + if (!isSpecialOp( + transferWriteOp->getOperand(0).getDefiningOp())) { + + opPermuationMap.insert( + {transferWriteOp, transferWriteOp.getPermutationMap()}); + transferWriteOp->getOperand(0).setType(newOperandType); + setOpVectorizationPermutationMap( + transferWriteOp, rewriter, + transferWriteOp->getResult(0) + .getType() + .dyn_cast(), + transferWriteOp.getPermutationMap()); + } + + return success(); + }) + .Case( + [&](vector::TransferReadOp transferReadOp) { + auto newOperandType = getVectorzedType(transferReadOp); + auto users = transferReadOp->getUsers(); + bool isUserSpecial = false; + for (auto *opUse : users) { + if (isSpecialOp(opUse)) { + isUserSpecial = true; + break; + } + } + if (!isUserSpecial) { + opPermuationMap.insert( + {transferReadOp, transferReadOp.getPermutationMap()}); + transferReadOp->getResult(0).setType(newOperandType); + setOpVectorizationPermutationMap( + transferReadOp, rewriter, + transferReadOp.getSource() + .getType() + .dyn_cast(), + transferReadOp.getPermutationMap()); + } + + return success(); + }) + .Case( + [&](vector::MultiDimReductionOp multiReductionOp) { + return success(); + }) + .Default([&](Operation *op) { + if (isSpecialOp(op)) { + return success(); + } + setOperationOperandResult(op, getVectorzedType(op), opMap); + return success(); + }); + if (failed(lowerResult)) { + LDBG("Failed to rewrite operation: " << *op << "\n"); + assert(false && "Failed to rewrite operation"); + } + } +} + +// analysis operation' operands are coming from which operation's result +// void analysisOperaionOperandSource( +// size_t idx, std::queue &grp, +// llvm::DenseMap &opGroupIndexMap, +// llvm::SmallVector, 8> &groupOperandNeedSet) { +// auto tmpOpQueue(grp); +// llvm::SetVector opOperands; +// while (!tmpOpQueue.empty()) { +// auto t = tmpOpQueue.front(); +// for (auto x : t->getOperands()) { +// // not in the same group +// if (opGroupIndexMap.contains(x.getDefiningOp()) && +// opGroupIndexMap[x.getDefiningOp()] != idx) { +// groupOperandNeedSet[idx].insert(x); +// } 
else { +// groupOperandNeedSet[idx].insert(x); +// } +// } +// tmpOpQueue.pop(); +// } +// } + // canonicalizing operation as tensor empty and transfer write the operation // result into the empty tensor [[nodiscard]] std::pair @@ -652,16 +807,16 @@ mlir::FailureOr getOperationDestnationOperand(Operation *op) { } // analysis operations of current group need which operation's result value -void analysisGroupOperationOperands( - llvm::SmallVector, 8> &opGroups, - llvm::DenseMap &opGroupIndexMap, - llvm::SmallVector, 8> &groupOperandNeedSet) { - - for (auto [idx, grp] : enumerate(opGroups)) { - analysisOperaionOperandSource(idx, grp, opGroupIndexMap, - groupOperandNeedSet); - } -} +// void analysisGroupOperationOperands( +// llvm::SmallVector, 8> &opGroups, +// llvm::DenseMap &opGroupIndexMap, +// llvm::SmallVector, 8> &groupOperandNeedSet) { + +// for (auto [idx, grp] : enumerate(opGroups)) { +// analysisOperaionOperandSource(idx, grp, opGroupIndexMap, +// groupOperandNeedSet); +// } +// } // TODO: need to rewrite reduce // llvm::SmallVector & @@ -704,12 +859,9 @@ void updateOpOperandResultInGroups( // analysis operation result of current group whether needed by other // operation which out of current group -void analysisGroupOperationResults( - func::FuncOp &func, llvm::SmallVector, 8> &opGroups, - IRMapping &mapOpResultToYield, - llvm::DenseMap &opGroupIndexMap, - llvm::SmallVector, 8> &groupResultYeildSet, - llvm::SmallVector, 8> &groupOpDestination) { +void CanonicalizerVectorOperation::analysisGroupOperationResults( + llvm::SmallVector, 8> &opGroups, + llvm::DenseMap &opGroupIndexMap) { llvm::DenseMap> srcOpCanoniclizedMap; IRRewriter rewriter(func); func.walk([&](Operation *op) { @@ -731,9 +883,8 @@ void analysisGroupOperationResults( updateOpOperandResultInGroups(opGroups, opGroupIndexMap, sourceOpGid, sourceOp, init, result); - groupOpDestination[sourceOpGid].insert(init); - groupResultYeildSet[sourceOpGid].insert(result); - mapOpResultToYield.map(result, result); + groupOpIterArgs[sourceOpGid].insert(init); + groupOpResults[sourceOpGid].insert(result); } auto opInit = canonicalizeCurrentOperation( @@ -751,11 +902,8 @@ void analysisGroupOperationResults( opGroupIndexMap[op], op, opInit); } else { - groupOpDestination[sourceOpGid].insert(dstRet.value()); - groupResultYeildSet[sourceOpGid].insert(opd); - - // just map to it self, placeholder - mapOpResultToYield.map(opd, opd); + groupOpIterArgs[sourceOpGid].insert(dstRet.value()); + groupOpResults[sourceOpGid].insert(opd); } } } @@ -765,28 +913,30 @@ void analysisGroupOperationResults( // If the group operations do not have result need to be returned, these are // useless code. 
for (auto [idx, grp] : enumerate(opGroups)) { - if (groupResultYeildSet[idx].empty()) { + if (groupOpResults[idx].empty()) { std::queue().swap(grp); } } LDBG("Complete analysis group operation results\n"); } -void analysisGroupOperaionOperandsResults( +void CanonicalizerVectorOperation::analysisGroupOperaionOperandsResults( llvm::SmallVector, 8> &opGroups, - llvm::DenseMap &opGroupIndexMap, - llvm::SmallVector, 8> &groupOperandNeedSet, - func::FuncOp &func, - llvm::SmallVector, 8> &groupResultYeildSet, - IRMapping &mapOpResultToYield, - llvm::SmallVector, 8> &groupOpDestination) { + llvm::DenseMap &opGroupIndexMap) { + // prepare + if (opGroups.size() != groupOpResults.size()) { + for (size_t i = 0; i < opGroups.size(); i++) { + groupOpResults.emplace_back(llvm::SetVector()); + groupOpIterArgs.emplace_back(llvm::SetVector()); + } + LDBG("Size of groupOpResults is : " << groupOpResults.size()); + } + // Operands - analysisGroupOperationOperands(opGroups, opGroupIndexMap, - groupOperandNeedSet); + // analysisGroupOperationOperands(opGroups, opGroupIndexMap); + // Results - analysisGroupOperationResults(func, opGroups, mapOpResultToYield, - opGroupIndexMap, groupResultYeildSet, - groupOpDestination); + analysisGroupOperationResults(opGroups, opGroupIndexMap); } mlir::FailureOr generateVectorizedForLoop( @@ -826,56 +976,53 @@ mlir::FailureOr generateVectorizedForLoop( return forOp; } -void updateLoopResultUses( - const size_t groupIdx, - llvm::SmallVector, 8> &groupResultYeildSet, - scf::ForOp *forOp) { - if (groupResultYeildSet[groupIdx].empty()) { +void updateLoopResultUses(llvm::SetVector &opResults, + scf::ForOp *forOp) { + if (opResults.empty()) { return; } IRRewriter rewriter(*forOp); OpBuilder::InsertionGuard g(rewriter); // Only different group operation operand need to be replaced due to same // group operation should directly use original operand. - auto producerOp = groupResultYeildSet[groupIdx].front().getDefiningOp(); + auto producerOp = opResults.front().getDefiningOp(); auto needToReplaced = [&](OpOperand &operand) { return producerOp->getBlock() != operand.getOwner()->getBlock(); }; // update loop result uses - for (auto [retIdx, rt] : llvm::enumerate(groupResultYeildSet[groupIdx])) { + for (auto [retIdx, rt] : llvm::enumerate(opResults)) { producerOp = rt.getDefiningOp(); rewriter.replaceUsesWithIf(rt, forOp->getResult(retIdx), needToReplaced); } } -void generateGroupOpVectorizedIR( - std::queue &grp, const size_t idx, - llvm::SmallVector, 8> &opGroups, - llvm::DenseMap &opGroupIndexMap, - llvm::SmallVector &groupsShapes, - llvm::SmallVector, 8> &groupResultYeildSet, - llvm::SmallVector, 8> &groupOpDestination, - IRMapping &mapOpResultToYield, func::FuncOp &func, - llvm::DenseMap &opPermuationMap) { +void CanonicalizerVectorOperation::generateGroupOpVectorizedIR( + const int idx, std::queue &grp, + llvm::DenseMap &opGroupIndexMap) { if (grp.empty()) { LDBG("Current operation Group is empty."); return; } + auto getType = getOperationVectorType(grp.front()); + if (failed(getType)) { + LDBG("Failed to get vector type for operation: " << *grp.front() << "\n"); + return; + } + auto opShapes = getType.value(); IRRewriter rewriter(grp.back()); rewriter.setInsertionPointAfter(grp.back()); // 1. Rewrite operation as vectorized form - rewriteOperationAsVectorize(opGroups[idx], opGroupIndexMap, rewriter, - opPermuationMap); + rewriteOperationAsVectorize(grp, opGroupIndexMap, rewriter, opPermuationMap); // 2. 
Generate loop - auto forOp = generateVectorizedForLoop( - rewriter, groupResultYeildSet[idx], groupOpDestination[idx], - groupsShapes[idx], opGroups[idx], opPermuationMap); + auto forOp = generateVectorizedForLoop(rewriter, groupOpResults[idx], + groupOpIterArgs[idx], opShapes, grp, + opPermuationMap); // special operation do not need to change anything if (failed(forOp)) { return; } // 3 Update loop result uses - updateLoopResultUses(idx, groupResultYeildSet, &forOp.value()); + updateLoopResultUses(groupOpResults[idx], &forOp.value()); moveLoopInvariantCode(forOp.value()); } @@ -889,93 +1036,9 @@ struct CPUPhysicalRegisterPass RewritePatternSet patterns(ctx); auto func = getOperation(); - // 1. Classify operaions: - // classify the operations into : - // a. reorder, transpose. Reorder(or transpose) dim may bring data - // dependency. - // b. elemenwise. Those operations can be fused into a common for loop. - // c. broadcast. Need to analysis broadcast dim and the data - // dependency. - // d. reduction. Need to analysis broadcast dim and the - // data dependency. - // Same group operations have no data dependencies. They can be fused into a - // common for loop body. - - // Using queue to store the operation order. In order to ensure that - // subsequent moves to the operation will not cause semantic changes. - llvm::SmallVector, 8> opGroups; - llvm::SmallVector groupsShapes; - // dummy - opGroups.emplace_back(std::queue()); - - // query current operation in which group, return group index - llvm::DenseMap opGroupIndexMap; - classifyOperations(func, opGroups, opGroupIndexMap, groupsShapes); - - // 2. Analysis the operation's operands and results - // We need to analyze which operation results are needed by other - // operations, and we need to pass these results correctly. Mapping the - // operation result value to forloop yeild result value. We can replace the - // operation operand as: map(operand, forloop yield result) -> operand = - // loop yield result We put all the operation result into this map. - - // 2.a. Find what results should be generated by current group for - // using as operands to other operations? - - // Traverse all operations. If the operand of operations in other groups or - // outside the group is the result of the current group operation, then the - // current operation needs to generate a result. We use `setvector` to save - // the results that need to be generated by the current group. - - // 2.b. What operands are needed to find in the current group, and where - // can they be obtained ? - - // Thanks to 2.a, we get the result generated by the operations of - // each group, and this result will use `for loop yield` to generate a - // new result. Since the scope of the parent block of mlir is covered - // the current operation, the current operation does not need to pass these - // `for loop results` to the `iter args` of the required `for loop`. It - // only needs to replace the operand of the current operation with the - // corresponding `for loop yield result`. - - // However, for some operations that are not DPS, we need to canonicalize - // them. Canonicalization means that the operand of this operation is a - // vector but we can't get this vector due to it locates in another block - // which has a different scope. Therefore, it is necessary to write the - // vector results into a temporary tensor to save it. Then the vector needs - // to be read from the tensor before the current operation operate on it. 
- // Therefore, `empty tensor`, `transfer_write` and `transfer_read` need to - // be inserted at target place. - - llvm::SmallVector, 8> groupOperandNeedSet( - opGroups.size(), llvm::SetVector()), - groupResultYeildSet(opGroups.size(), llvm::SetVector()), - groupOpDestination(opGroups.size(), llvm::SetVector()); - // Query groupResultYeildSet to map operaion result value to scf.yield - // result value. - IRMapping mapOpResultToYield; - analysisGroupOperaionOperandsResults( - opGroups, opGroupIndexMap, groupOperandNeedSet, func, - groupResultYeildSet, mapOpResultToYield, groupOpDestination); - - OpBuilder builder(ctx); - // store read and write operations permutation maps in order to convenient - // to replace loop induction var - llvm::DenseMap opPermuationMap; - - // 3.Generate vectorized IR for each operation group - for (auto [idx, grp] : llvm::enumerate(opGroups)) { - - generateGroupOpVectorizedIR(grp, idx, opGroups, opGroupIndexMap, - groupsShapes, groupResultYeildSet, - groupOpDestination, mapOpResultToYield, func, - opPermuationMap); - } - - // 4. Some IR cleanup work - DominanceInfo domInfo; - auto reWriter = IRRewriter(func); - eliminateCommonSubExpressions(reWriter, domInfo, func); + // canonicalize vector operation + CanonicalizerVectorOperation canonicalizer(func); + canonicalizer.run(); } }; } // namespace diff --git a/lib/gc/Transforms/LowerTileVectorPass.cpp b/lib/gc/Transforms/LowerTileVectorPass.cpp index 944e38775..89b8e5354 100644 --- a/lib/gc/Transforms/LowerTileVectorPass.cpp +++ b/lib/gc/Transforms/LowerTileVectorPass.cpp @@ -480,8 +480,8 @@ struct LowerTileVectorPass tensor::populateFoldTensorEmptyPatterns(patterns, true); tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns); vector::VectorTransformsOptions vectorTransformOptions; - vector::populateVectorMultiReductionLoweringPatterns( - patterns, vectorTransformOptions.vectorMultiReductionLowering); + // vector::populateVectorMultiReductionLoweringPatterns( + // patterns, vectorTransformOptions.vectorMultiReductionLowering); // vector::populateVectorShapeCastLoweringPatterns(patterns); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); diff --git a/test/gc/transforms/cpu-vetor-distribution.mlir b/test/gc/transforms/cpu-vetor-distribution.mlir index 5b3235c71..e4e39ec33 100644 --- a/test/gc/transforms/cpu-vetor-distribution.mlir +++ b/test/gc/transforms/cpu-vetor-distribution.mlir @@ -1,11 +1,10 @@ // RUN: gc-opt --split-input-file --lower-to-tile-vector --CPU-physical-register-pass --mlir-print-ir-after-all -- %s // CHECK-LABEL: func @add_tensor -// func.func @add_tensor_test0(%arg0: tensor<4x8x16xf32>, %arg1: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> { -// %0 = tensor.empty() : tensor<4x8x16xf32> -// %1 = linalg.add ins(%arg0, %arg1 : tensor<4x8x16xf32>, tensor<4x8x16xf32>) outs(%0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> -// %2 = linalg.add ins(%1, %arg1 : tensor<4x8x16xf32>, tensor<4x8x16xf32>) outs(%0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> -// return %2 : tensor<4x8x16xf32> +// func.func @add_tensor_test0(%arg0: tensor<4x8x1024xf32>, %arg1: tensor<4x8x1024xf32>) -> tensor<4x8x1024xf32> { +// %0 = tensor.empty() : tensor<4x8x1024xf32> +// %1 = linalg.add ins(%arg0, %arg1 : tensor<4x8x1024xf32>, tensor<4x8x1024xf32>) outs(%0: tensor<4x8x1024xf32>) -> tensor<4x8x1024xf32> +// return %1 : tensor<4x8x1024xf32> // } // func.func @fc_relu(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>, @@ -28,6 +27,7 @@ // outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> // 
func.return %relued : tensor<512x512xf32> // } + func.func @reduce_keepdim0(%arg0: tensor<16x32x64xf32>) -> tensor<16x1x64xf32> { %0 = tensor.empty() : tensor<16x64xf32> %reduce = linalg.reduce From 0709481f1c3b8f52b3b215353bb8bbcdb954edd5 Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Mon, 3 Jun 2024 17:10:15 +0800 Subject: [PATCH 05/66] backup multireduction canonicalization --- lib/gc/Transforms/CPUPhysicalResigterPass.cpp | 462 ++++++++++++++---- 1 file changed, 364 insertions(+), 98 deletions(-) diff --git a/lib/gc/Transforms/CPUPhysicalResigterPass.cpp b/lib/gc/Transforms/CPUPhysicalResigterPass.cpp index bc4163f63..8ab467960 100644 --- a/lib/gc/Transforms/CPUPhysicalResigterPass.cpp +++ b/lib/gc/Transforms/CPUPhysicalResigterPass.cpp @@ -95,7 +95,13 @@ mlir::FailureOr getOperationVectorType(Operation *op) { .Case( [&](vector::TransferWriteOp transferWriteOp) -> mlir::FailureOr { - return transferWriteOp.getVectorType(); + auto retType = mlir::dyn_cast( + transferWriteOp->getOperand(0).getType()); + if (retType) { + return retType; + } + LDBG("TransferWrite Operation has wrong vector to write."); + return failure(); }) .Case([&](vector::TransferReadOp transferReadOp) -> mlir::FailureOr { @@ -114,6 +120,298 @@ mlir::FailureOr getOperationVectorType(Operation *op) { }); } +VectorType getVectorzedType(Operation *op) { + // Check that the operation type can be broken + // down into a loop. + auto baseType = getOperationVectorType(op); + if (failed(baseType)) { + LDBG("Failed to get vector type for operation: " << *op << "\n"); + assert(false && "Failed to get vector type for operation"); + return VectorType(); + } + auto vectorizedType = baseType.value(); + const int loop_step = getDataTypeMAXSIMDLength(vectorizedType); + return VectorType::get({loop_step}, vectorizedType.getElementType()); +} + +// Since we rewrite transfer_read and transfer_write, the `permutationmap` must +// be changed. +void setOpVectorizationPermutationMap(Operation *op, OpBuilder &rewriter, + const RankedTensorType &tensorType, + const AffineMap &permutationMap) { + + auto dimExpr = permutationMap.getResults(); + auto lastDim = mlir::dyn_cast(dimExpr.back()); + assert(mlir::isa(lastDim)); + + SmallVector affineExprs; + affineExprs.push_back(lastDim); + auto destAffineMap = AffineMap::get(tensorType.getRank(), 0, affineExprs, + rewriter.getContext()); + SmallVector inBounds(1, true); + if (mlir::isa(op)) { + auto transferWriteOp = mlir::dyn_cast(op); + transferWriteOp.setPermutationMap(destAffineMap); + transferWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds)); + } else if (mlir::isa(op)) { + auto transferReadOp = mlir::dyn_cast(op); + transferReadOp.setPermutationMap(destAffineMap); + transferReadOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds)); + } +} + +// scf.for yield helper function +void maybeYieldValue(OpBuilder &b, Location loc, ValueRange value) { + bool hasRetVal = !value.empty(); + if (hasRetVal) { + b.create(loc, value); + } else { + b.create(loc); + } +} + +Type getScalarType(Operation *op) { + // Check that the operation type can be broken + // down into a loop. 
+ auto baseType = getOperationVectorType(op); + if (failed(baseType)) { + LDBG("Failed to get vector type for operation: " << *op << "\n"); + assert(false && "Failed to get vector type for operation"); + return VectorType(); + } + auto vectorizedType = baseType.value(); + return VectorType::get({1}, vectorizedType.getElementType()); +} + +// __________________________________ +// Speical operations canonicalization +// __________________________________ + +mlir::FailureOr generateMultiReductionForLoop( + OpBuilder &opBuilder, vector::MultiDimReductionOp &multiReductionOp, + const llvm::SmallVector ¶llelAxis, const size_t parallelIdx, + const llvm::SmallVector &reductionAxis, + const size_t reductionIdx, const VectorType &vectorType, + llvm::SmallVector &inductionVars, const ValueRange &iterArgs, + Value &originalWriteResult, bool lastDimReduction) { + const int loop_step = getDataTypeMAXSIMDLength(vectorType); + auto loc = multiReductionOp->getLoc(); + auto zero = opBuilder.create( + loc, opBuilder.getIndexType(), + opBuilder.getIntegerAttr(opBuilder.getIndexType(), 0)); + + scf::ForOp forOp = nullptr; + // parallel axis + if (parallelIdx < parallelAxis.size()) { + auto forSteps = opBuilder.create( + loc, opBuilder.getIndexType(), + opBuilder.getIntegerAttr(opBuilder.getIndexType(), 1)); + auto numIter = opBuilder.create( + loc, opBuilder.getIndexType(), + opBuilder.getIntegerAttr( + opBuilder.getIndexType(), + vectorType.getShape()[parallelAxis[parallelIdx]])); + // Create a loop and move vectorized operation into loops. + forOp = opBuilder.create( + loc, zero, numIter, forSteps, iterArgs, + [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) { + inductionVars.emplace_back(iv); + if (parallelIdx == parallelAxis.size() - 1) { + // move original transfer_read operation into parallel axis loop + // body + // get read operation + Value multiReductionAcc = multiReductionOp.getAcc(); + auto accReadOp = + multiReductionAcc.getDefiningOp(); + assert(accReadOp && + " Not transfer_read operation. Current multireduction " + "operation may have wrong analysis IR."); + // get write operation + vector::TransferWriteOp accWriteOp = nullptr; + for (auto [idx, x] : llvm::enumerate( + multiReductionOp->getResults()[0].getUsers())) { + if (idx == 0 && mlir::isa(x)) { + accWriteOp = mlir::dyn_cast(x); + break; + } + } + accWriteOp->dump(); + std::cout << "write operation dump checking..." 
<< std::endl; + assert(accWriteOp); + IRMapping accReadindiceMap; + // parallel + for (auto [idx, axis] : enumerate(parallelAxis)) { + accReadindiceMap.map(accReadOp.getIndices()[axis], + inductionVars[idx]); + } + IRRewriter bodyRewriter(b); + auto newAccReadOp = mlir::dyn_cast( + b.clone(*accReadOp, accReadindiceMap)); + bodyRewriter.replaceOp(accReadOp, newAccReadOp); + // constructe next for loop + auto accVal = b.create( + loc, opBuilder.getZeroAttr(vectorType.getElementType())); + ValueRange newIterArgs(accVal); + auto nxtFor = generateMultiReductionForLoop( + b, multiReductionOp, parallelAxis, parallelIdx + 1, + reductionAxis, reductionIdx, vectorType, inductionVars, + newIterArgs, originalWriteResult, lastDimReduction); + + // move original transfer_write into loop + auto accRes = nxtFor.value()->getResults()[0]; + + // replace the vector as the loop return vector value + IRMapping accWriteindiceMap; + accWriteindiceMap.map(accWriteOp.getOperand(0), accRes); + auto newAccWriteOp = mlir::dyn_cast( + b.clone(*accWriteOp, accWriteindiceMap)); + int offset = 2; + for (auto [idx, inductionVar] : llvm::enumerate(inductionVars)) { + if (idx >= parallelAxis.size()) { + break; + } + newAccWriteOp->setOperand(idx + offset, inductionVar); + } + setOpVectorizationPermutationMap(newAccWriteOp, opBuilder, + newAccWriteOp->getResult(0) + .getType() + .dyn_cast(), + newAccWriteOp.getPermutationMap()); + originalWriteResult = accWriteOp->getResult(0); + bodyRewriter.replaceOp(accWriteOp, newAccWriteOp); + maybeYieldValue(b, loc, newAccWriteOp->getResults()); + } else { + auto nxtFor = generateMultiReductionForLoop( + b, multiReductionOp, parallelAxis, parallelIdx + 1, + reductionAxis, reductionIdx, vectorType, inductionVars, + iterArgs, originalWriteResult, lastDimReduction); + maybeYieldValue(b, loc, nxtFor.value()->getResults()); + } + }); + + } else { + + auto forSteps = opBuilder.create( + loc, opBuilder.getIndexType(), + opBuilder.getIntegerAttr( + opBuilder.getIndexType(), + reductionIdx == reductionAxis.size() - 1 ? loop_step : 1)); + auto numIter = opBuilder.create( + loc, opBuilder.getIndexType(), + opBuilder.getIntegerAttr(opBuilder.getIndexType(), + vectorType.getShape()[reductionIdx])); + forOp = opBuilder.create( + loc, zero, numIter, forSteps, iterArgs, + [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) { + inductionVars.emplace_back(iv); + + if (reductionIdx == reductionAxis.size() - 1) { + auto source = multiReductionOp->getOperand(0); + + auto readOp = + mlir::dyn_cast(source.getDefiningOp()); + assert(readOp); + IRMapping indiceMap; + IRRewriter rewriter(b); + auto clonedOp = b.clone(*readOp, indiceMap); + int offset = 1; + auto newReadOp = mlir::dyn_cast(clonedOp); + + for (auto [idx, inductionVar] : enumerate(inductionVars)) { + newReadOp->setOperand(idx + offset, inductionVar); + } + + auto newOperandType = lastDimReduction ? 
getVectorzedType(newReadOp) + : getScalarType(newReadOp); + newReadOp->getResult(0).setType(newOperandType); + setOpVectorizationPermutationMap( + newReadOp, b, + newReadOp.getSource().getType().dyn_cast(), + newReadOp.getPermutationMap()); + rewriter.replaceOp(readOp, newReadOp); + if (lastDimReduction) { + Operation *reductionOp = rewriter.create( + loc, multiReductionOp.getKind(), newReadOp->getResult(0), + iterArgs.back()); + maybeYieldValue(b, loc, reductionOp->getResults()); + } else { + auto reductionResult = + makeArithReduction(b, loc, multiReductionOp.getKind(), + newReadOp->getResult(0), iterArgs.back()); + maybeYieldValue(b, loc, reductionResult); + } + } else { + // outter loop + auto nxtFor = generateMultiReductionForLoop( + b, multiReductionOp, parallelAxis, parallelIdx, reductionAxis, + reductionIdx + 1, vectorType, inductionVars, iterArgs, + originalWriteResult, lastDimReduction); + maybeYieldValue(b, loc, nxtFor.value()->getResults()); + } + }); + } + return forOp; +} + +LogicalResult +canonicalizeReductionOperation(vector::MultiDimReductionOp &multiReductionOp, + IRRewriter &rewriter) { + OpBuilder::InsertionGuard guard(rewriter); + + auto srcVecType = multiReductionOp.getSourceVectorType(); + auto srcRank = multiReductionOp.getSourceVectorType().getRank(); + + // Separate reduction and parallel dims + bool lastDimReduction = false; + auto reductionAxisRange = + multiReductionOp.getReductionDims().getAsValueRange(); + auto reductionRange = llvm::to_vector<4>(llvm::map_range( + reductionAxisRange, [](const APInt &a) { return a.getZExtValue(); })); + llvm::SmallVector reductionAxis(reductionRange.begin(), + reductionRange.end()); + llvm::SmallDenseSet reductionAxisSet(reductionAxis.begin(), + reductionAxis.end()); + if (reductionAxisSet.contains(srcRank - 1)) { + lastDimReduction = true; + } + SmallVector parallelAxis; + for (int64_t i = 0; i < srcRank; ++i) { + if (!reductionAxisSet.contains(i)) { + parallelAxis.push_back(i); + } + } + /* + * The final IR may look like below: + * _for_(_fuseiter_i, 0, 1) + * sum = 0; + * _for_(_fuseiter_j, 0, 1) + * _for_(_fuseiter_k, 0, 1) + * sum += src[src_idx]; + * dst[dst_idx] = sum; + * */ + Operation *newReduction; + Value multiReductionAcc = multiReductionOp.getAcc(); + auto accTensorReadOp = + multiReductionAcc.getDefiningOp(); + Value originalWriteResult; + ValueRange iterArgs(accTensorReadOp->getOperand(0)); + llvm::SmallVector inductionVars; + auto forOp = generateMultiReductionForLoop( + rewriter, multiReductionOp, parallelAxis, 0, reductionAxis, 0, srcVecType, + inductionVars, iterArgs, originalWriteResult, lastDimReduction); + if (failed(forOp)) { + LDBG("MultiReduction Operation lowering failed"); + return failure(); + } + newReduction = forOp.value(); + rewriter.replaceAllUsesWith(originalWriteResult, + newReduction->getResults()[0]); + forOp->dump(); + + rewriter.replaceOp(multiReductionOp, newReduction); + return success(); +} + // 1. Classify operaions: // classify the operations into : // a. reorder, transpose. 
Reorder(or transpose) dim may bring data @@ -164,7 +462,8 @@ void VectorFusionStrategy::run() { enum CanonicalizerKind { OperationsGroup, Operations }; -struct CanonicalizerVectorOperation { +class CanonicalizerVectorOperation { +public: func::FuncOp func; IRRewriter rewriter; VectorFusionStrategy fusionStrategy; @@ -187,6 +486,7 @@ struct CanonicalizerVectorOperation { fusionStrategy.run(); } } + func::FuncOp getFunc() { return func; }; void generateGroupOpVectorizedIR( const int idx, std::queue &grp, @@ -200,60 +500,77 @@ struct CanonicalizerVectorOperation { llvm::SmallVector, 8> &opGroups, llvm::DenseMap &opGroupIndexMap); + void canonicalizeSpecialOperation(); + void run(); }; +void CanonicalizerVectorOperation::canonicalizeSpecialOperation() { + llvm::SetVector reductionOps; + func->walk( + [&](vector::MultiDimReductionOp multiReductionOp) { + reductionOps.insert(multiReductionOp); + }); + for (auto x : reductionOps) { + IRRewriter rewriter(x); + (void)canonicalizeReductionOperation(x, rewriter); + } +} + void CanonicalizerVectorOperation::run() { - // 1. Analysis the operation's operands and results - // We need to analyze which operation results are needed by other - // operations, and we need to pass these results correctly. Mapping the - // operation result value to forloop yeild result value. We can replace the - // operation operand as: map(operand, forloop yield result) -> operand = - // loop yield result We put all the operation result into this map. - - // 2.a. Find what results should be generated by current group for - // using as operands to other operations? - - // Traverse all operations. If the operand of operations in other groups or - // outside the group is the result of the current group operation, then the - // current operation needs to generate a result. We use `setvector` to save - // the results that need to be generated by the current group. - - // 2.b. What operands are needed to find in the current group, and where - // can they be obtained ? - - // Thanks to 2.a, we get the result generated by the operations of - // each group, and this result will use `for loop yield` to generate a - // new result. Since the scope of the parent block of mlir is covered - // the current operation, the current operation does not need to pass these - // `for loop results` to the `iter args` of the required `for loop`. It - // only needs to replace the operand of the current operation with the - // corresponding `for loop yield result`. - - // However, for some operations that are not DPS, we need to canonicalize - // them. Canonicalization means that the operand of this operation is a - // vector but we can't get this vector due to it locates in another block - // which has a different scope. Therefore, it is necessary to write the - // vector results into a temporary tensor to save it. Then the vector needs - // to be read from the tensor before the current operation operate on it. - // Therefore, `empty tensor`, `transfer_write` and `transfer_read` need to - // be inserted at target place. - - // Query groupResultYeildSet to map operaion result value to scf.yield - // result value. + if (kind == CanonicalizerKind::OperationsGroup) { + // 1. Analysis the operation's operands and results + // We need to analyze which operation results are needed by other + // operations, and we need to pass these results correctly. Mapping the + // operation result value to forloop yeild result value. 
We can replace the + // operation operand as: map(operand, forloop yield result) -> operand = + // loop yield result We put all the operation result into this map. + + // 1.a. Find what results should be generated by current group for + // using as operands to other operations? + + // Traverse all operations. If the operand of operations in other groups or + // outside the group is the result of the current group operation, then the + // current operation needs to generate a result. We use `setvector` to save + // the results that need to be generated by the current group. + + // 1.b. What operands are needed to find in the current group, and where + // can they be obtained ? + + // Thanks to 2.a, we get the result generated by the operations of + // each group, and this result will use `for loop yield` to generate a + // new result. Since the scope of the parent block of mlir is covered + // the current operation, the current operation does not need to pass these + // `for loop results` to the `iter args` of the required `for loop`. It + // only needs to replace the operand of the current operation with the + // corresponding `for loop yield result`. + + // However, for some operations that are not DPS, we need to canonicalize + // them. Canonicalization means that the operand of this operation is a + // vector but we can't get this vector due to it locates in another block + // which has a different scope. Therefore, it is necessary to write the + // vector results into a temporary tensor to save it. Then the vector needs + // to be read from the tensor before the current operation operate on it. + // Therefore, `empty tensor`, `transfer_write` and `transfer_read` need to + // be inserted at target place. + + // Query groupResultYeildSet to map operaion result value to scf.yield + // result value. analysisGroupOperaionOperandsResults(fusionStrategy.getOpGroups(), fusionStrategy.getOpGroupIndexMap()); - // 3.Generate vectorized IR for each operation group + + // Speical Operation Canonicalization + // canonicalizeSpecialOperation(); + // 2.Generate vectorized IR for each operation group for (auto [idx, grp] : llvm::enumerate(fusionStrategy.getOpGroups())) { generateGroupOpVectorizedIR(idx, grp, fusionStrategy.getOpGroupIndexMap()); } - // 4. Some IR cleanup work + // 3. Some IR cleanup work DominanceInfo domInfo; - auto reWriter = IRRewriter(func); - eliminateCommonSubExpressions(reWriter, domInfo, func); + eliminateCommonSubExpressions(rewriter, domInfo, func); } else { // TODO: need to add directly canonicalize operations // generateGroupOpVectorizedIR(idx, grp, fusionStrategy.opGroupIndexMap); @@ -276,42 +593,6 @@ void CanonicalizerVectorOperation::run() { return true; } -// Since we rewrite transfer_read and transfer_write, the `permutationmap` must -// be changed. 
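Stepping back to the group analysis that the comments at the top of run() describe: step 1.a amounts to collecting every result of a group that is consumed outside that group, since exactly those values have to be yielded by the group's loop. A minimal sketch of that collection is shown below; it assumes a simplified set-based group representation and is not the pass's analysisGroupOperaionOperandsResults routine.

    // Illustrative only: gather the results of a fused group that are used by
    // operations outside the group; these are the values the group's scf.for
    // needs to yield.
    llvm::SetVector<Value> collectEscapingResults(
        const llvm::SmallPtrSet<Operation *, 8> &group) {
      llvm::SetVector<Value> escaping;
      for (Operation *op : group)
        for (Value result : op->getResults())
          for (Operation *user : result.getUsers())
            if (!group.contains(user))
              escaping.insert(result);
      return escaping;
    }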
-void setOpVectorizationPermutationMap(Operation *op, IRRewriter &rewriter, - const RankedTensorType &tensorType, - const AffineMap &permutationMap) { - - auto dimExpr = permutationMap.getResults(); - auto lastDim = mlir::dyn_cast(dimExpr.back()); - assert(mlir::isa(lastDim)); - - SmallVector affineExprs; - affineExprs.push_back(lastDim); - auto destAffineMap = AffineMap::get(tensorType.getRank(), 0, affineExprs, - rewriter.getContext()); - SmallVector inBounds(1, true); - if (mlir::isa(op)) { - auto transferWriteOp = mlir::dyn_cast(op); - transferWriteOp.setPermutationMap(destAffineMap); - transferWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds)); - } else if (mlir::isa(op)) { - auto transferReadOp = mlir::dyn_cast(op); - transferReadOp.setPermutationMap(destAffineMap); - transferReadOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds)); - } -} - -// scf.for yield helper function -void maybeYieldValue(OpBuilder &b, Location loc, ValueRange value) { - bool hasRetVal = !value.empty(); - if (hasRetVal) { - b.create(loc, value); - } else { - b.create(loc); - } -} - // void checkAndSetOperand( Operation *op, const ValueRange &iterArgs, @@ -606,7 +887,7 @@ createTransferReadOpBefore(Operation *op, const Value &operand, auto srcReadOpAffineMap = srcReadOp->getPermutationMap(); // result of read operation should be same as operand auto t = rewriter.create( - rewriter.getUnknownLoc(), + op->getLoc(), /*vectorType=*/ VectorType::get(resultType.getShape(), resultType.getElementType()), /*source=*/operand, @@ -618,7 +899,7 @@ createTransferReadOpBefore(Operation *op, const Value &operand, } else { SmallVector inBoundsVal(operandType.getRank(), true); auto t = rewriter.create( - rewriter.getUnknownLoc(), + op->getLoc(), /*vectorType=*/ VectorType::get(operandType.getShape(), operandType.getElementType()), /*source=*/operand, @@ -665,20 +946,6 @@ void rewriteOperationAsVectorize( llvm::DenseMap &opPermuationMap) { std::queue transformQueue(groupOps); - auto getVectorzedType = [](Operation *op) -> VectorType { - // Check that the operation type can be broken - // down into a loop. 
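The lambda removed just below computed the rank-1 vector type that the rewritten loop body operates on, with the SIMD lane count as its length. A condensed sketch of that computation, reusing this file's getOperationVectorType and getDataTypeMAXSIMDLength helpers (returning FailureOr is a small liberty taken here in place of the original assert):

    // Sketch: derive the rank-1 vector type used inside the generated loops;
    // its length is the SIMD lane count for the element type.
    mlir::FailureOr<VectorType> getLoopBodyVectorType(Operation *op) {
      auto nDType = getOperationVectorType(op);
      if (failed(nDType))
        return failure();
      const int lanes = getDataTypeMAXSIMDLength(nDType.value());
      return VectorType::get({lanes}, nDType.value().getElementType());
    }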
- auto baseType = getOperationVectorType(op); - if (failed(baseType)) { - LDBG("Failed to get vector type for operation: " << *op << "\n"); - assert(false && "Failed to get vector type for operation"); - return VectorType(); - } - auto vectorizedType = baseType.value(); - const int loop_step = getDataTypeMAXSIMDLength(vectorizedType); - return VectorType::get({loop_step}, vectorizedType.getElementType()); - }; - while (!transformQueue.empty()) { auto op = transformQueue.front(); transformQueue.pop(); @@ -820,9 +1087,9 @@ mlir::FailureOr getOperationDestnationOperand(Operation *op) { // TODO: need to rewrite reduce // llvm::SmallVector & -// getReductionDims(vector::MultiDimReductionOp &reductionOp, +// getreductionAxis(vector::MultiDimReductionOp &reductionOp, // llvm::SmallVector &rdDims) { -// auto rdDimsAttr = reductionOp.getReductionDims().getValue(); +// auto rdDimsAttr = reductionOp.getreductionAxis().getValue(); // for (auto x : rdDimsAttr) { // rdDims.emplace_back(x.cast().getInt()); // } @@ -1035,7 +1302,6 @@ struct CPUPhysicalRegisterPass auto *ctx = &getContext(); RewritePatternSet patterns(ctx); auto func = getOperation(); - // canonicalize vector operation CanonicalizerVectorOperation canonicalizer(func); canonicalizer.run(); From e04eaf60374cc35f7d43682f68640f2b1675481b Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Tue, 4 Jun 2024 15:20:27 +0800 Subject: [PATCH 06/66] update reduce operation --- lib/gc/Transforms/CPUPhysicalResigterPass.cpp | 768 +++++++++++++----- lib/gc/Transforms/LowerTileVectorPass.cpp | 6 +- 2 files changed, 558 insertions(+), 216 deletions(-) diff --git a/lib/gc/Transforms/CPUPhysicalResigterPass.cpp b/lib/gc/Transforms/CPUPhysicalResigterPass.cpp index 8ab467960..7d9890cd6 100644 --- a/lib/gc/Transforms/CPUPhysicalResigterPass.cpp +++ b/lib/gc/Transforms/CPUPhysicalResigterPass.cpp @@ -187,231 +187,410 @@ Type getScalarType(Operation *op) { // Speical operations canonicalization // __________________________________ -mlir::FailureOr generateMultiReductionForLoop( +//===----------------------------------------------------------------------===// +// MultiReduce Operation +//===----------------------------------------------------------------------===// + +enum class MultiReduceOpAxisKind { Reduction, Parallel }; +void updateReduceReadWriteOperationOperand( + const llvm::SmallVector &inductionVars, + const llvm::SmallVector ¶llelAxis, Operation *op, + MultiReduceOpAxisKind rdKind = MultiReduceOpAxisKind::Parallel) { + int indiceOffset = mlir::isa(op) ? 1 : 2; + for (auto [idx, inductionVar] : llvm::enumerate(inductionVars)) { + if (rdKind == MultiReduceOpAxisKind::Parallel && + idx >= parallelAxis.size()) { + break; + } + op->setOperand(idx + indiceOffset, inductionVar); + } +} + +vector::TransferReadOp makeNewTransferReadOp( + Value &source, OpBuilder &b, IRMapping &readMap, + const llvm::SmallVector ¶llelAxis, + llvm::SmallVector &inductionVars, bool lastDimReduction, + MultiReduceOpAxisKind rdKind = MultiReduceOpAxisKind::Parallel) { + IRRewriter rewriter(b); + auto readOp = mlir::dyn_cast(source.getDefiningOp()); + assert(readOp && " Not transfer_read operation. 
Current multireduction " + "operation may have wrong analysis IR."); + + auto clonedOp = b.clone(*readOp, readMap); + auto newReadOp = mlir::dyn_cast(clonedOp); + updateReduceReadWriteOperationOperand(inductionVars, parallelAxis, newReadOp, + rdKind); + + // modify the type of the new read operation + auto newOperandType = + (lastDimReduction && rdKind == MultiReduceOpAxisKind::Reduction) + ? getVectorzedType(newReadOp) + : getScalarType(newReadOp); + newReadOp->getResult(0).setType(newOperandType); + setOpVectorizationPermutationMap( + newReadOp, b, + newReadOp.getSource().getType().dyn_cast(), + newReadOp.getPermutationMap()); + + rewriter.replaceOp(readOp, newReadOp); + return newReadOp; +} + +vector::TransferWriteOp +makeNewTransferWriteOp(Value source, IRMapping &writeMap, OpBuilder &b, + const llvm::SmallVector ¶llelAxis, + llvm::SmallVector &inductionVars) { + IRRewriter bodyRewriter(b); + auto writeOp = source.getDefiningOp(); + auto newWriteOp = + mlir::dyn_cast(b.clone(*writeOp, writeMap)); + updateReduceReadWriteOperationOperand(inductionVars, parallelAxis, newWriteOp, + MultiReduceOpAxisKind::Parallel); + setOpVectorizationPermutationMap( + newWriteOp, b, + newWriteOp->getResult(0).getType().dyn_cast(), + newWriteOp.getPermutationMap()); + bodyRewriter.replaceOp(writeOp, newWriteOp); + return newWriteOp; +} + +scf::ForOp reductionAxisGenerateForLoop( OpBuilder &opBuilder, vector::MultiDimReductionOp &multiReductionOp, - const llvm::SmallVector ¶llelAxis, const size_t parallelIdx, const llvm::SmallVector &reductionAxis, const size_t reductionIdx, const VectorType &vectorType, llvm::SmallVector &inductionVars, const ValueRange &iterArgs, - Value &originalWriteResult, bool lastDimReduction) { - const int loop_step = getDataTypeMAXSIMDLength(vectorType); - auto loc = multiReductionOp->getLoc(); + bool lastDimReduction, Location &loc, const int loopStep) { + auto zero = opBuilder.create( loc, opBuilder.getIndexType(), opBuilder.getIntegerAttr(opBuilder.getIndexType(), 0)); + auto forSteps = opBuilder.create( + loc, opBuilder.getIndexType(), + opBuilder.getIntegerAttr( + opBuilder.getIndexType(), + (reductionIdx == reductionAxis.size() - 1 && lastDimReduction) + ? loopStep + : 1)); + auto numIter = opBuilder.create( + loc, opBuilder.getIndexType(), + opBuilder.getIntegerAttr(opBuilder.getIndexType(), + vectorType.getShape()[reductionIdx])); + auto forOp = opBuilder.create( + loc, zero, numIter, forSteps, iterArgs, + [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) { + inductionVars.emplace_back(iv); - scf::ForOp forOp = nullptr; - // parallel axis - if (parallelIdx < parallelAxis.size()) { - auto forSteps = opBuilder.create( - loc, opBuilder.getIndexType(), - opBuilder.getIntegerAttr(opBuilder.getIndexType(), 1)); - auto numIter = opBuilder.create( - loc, opBuilder.getIndexType(), - opBuilder.getIntegerAttr( - opBuilder.getIndexType(), - vectorType.getShape()[parallelAxis[parallelIdx]])); - // Create a loop and move vectorized operation into loops. - forOp = opBuilder.create( - loc, zero, numIter, forSteps, iterArgs, - [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) { - inductionVars.emplace_back(iv); - if (parallelIdx == parallelAxis.size() - 1) { - // move original transfer_read operation into parallel axis loop - // body - // get read operation - Value multiReductionAcc = multiReductionOp.getAcc(); - auto accReadOp = - multiReductionAcc.getDefiningOp(); - assert(accReadOp && - " Not transfer_read operation. 
Current multireduction " - "operation may have wrong analysis IR."); - // get write operation - vector::TransferWriteOp accWriteOp = nullptr; - for (auto [idx, x] : llvm::enumerate( - multiReductionOp->getResults()[0].getUsers())) { - if (idx == 0 && mlir::isa(x)) { - accWriteOp = mlir::dyn_cast(x); - break; - } - } - accWriteOp->dump(); - std::cout << "write operation dump checking..." << std::endl; - assert(accWriteOp); - IRMapping accReadindiceMap; - // parallel - for (auto [idx, axis] : enumerate(parallelAxis)) { - accReadindiceMap.map(accReadOp.getIndices()[axis], - inductionVars[idx]); - } - IRRewriter bodyRewriter(b); - auto newAccReadOp = mlir::dyn_cast( - b.clone(*accReadOp, accReadindiceMap)); - bodyRewriter.replaceOp(accReadOp, newAccReadOp); - // constructe next for loop - auto accVal = b.create( - loc, opBuilder.getZeroAttr(vectorType.getElementType())); - ValueRange newIterArgs(accVal); - auto nxtFor = generateMultiReductionForLoop( - b, multiReductionOp, parallelAxis, parallelIdx + 1, - reductionAxis, reductionIdx, vectorType, inductionVars, - newIterArgs, originalWriteResult, lastDimReduction); - - // move original transfer_write into loop - auto accRes = nxtFor.value()->getResults()[0]; - - // replace the vector as the loop return vector value - IRMapping accWriteindiceMap; - accWriteindiceMap.map(accWriteOp.getOperand(0), accRes); - auto newAccWriteOp = mlir::dyn_cast( - b.clone(*accWriteOp, accWriteindiceMap)); - int offset = 2; - for (auto [idx, inductionVar] : llvm::enumerate(inductionVars)) { - if (idx >= parallelAxis.size()) { - break; - } - newAccWriteOp->setOperand(idx + offset, inductionVar); - } - setOpVectorizationPermutationMap(newAccWriteOp, opBuilder, - newAccWriteOp->getResult(0) - .getType() - .dyn_cast(), - newAccWriteOp.getPermutationMap()); - originalWriteResult = accWriteOp->getResult(0); - bodyRewriter.replaceOp(accWriteOp, newAccWriteOp); - maybeYieldValue(b, loc, newAccWriteOp->getResults()); + if (reductionIdx == reductionAxis.size() - 1) { + IRRewriter rewriter(b); + IRMapping readMap; + Value reductionTarget = multiReductionOp->getOperand(0); + llvm::SmallVector parallelAxis; + auto newReadOp = makeNewTransferReadOp( + reductionTarget, b, readMap, parallelAxis, inductionVars, + lastDimReduction, MultiReduceOpAxisKind::Reduction); + + // reduction or elementwise reduce + if (lastDimReduction) { + Operation *reductionOp = rewriter.create( + loc, multiReductionOp.getKind(), newReadOp->getResult(0), + loopState.back()); + maybeYieldValue(b, loc, reductionOp->getResults()); } else { - auto nxtFor = generateMultiReductionForLoop( - b, multiReductionOp, parallelAxis, parallelIdx + 1, - reductionAxis, reductionIdx, vectorType, inductionVars, - iterArgs, originalWriteResult, lastDimReduction); - maybeYieldValue(b, loc, nxtFor.value()->getResults()); + auto reductionResult = + makeArithReduction(b, loc, multiReductionOp.getKind(), + newReadOp->getResult(0), loopState.back()); + maybeYieldValue(b, loc, reductionResult); } - }); - - } else { + } else { + // outter loop + auto nxtFor = reductionAxisGenerateForLoop( + b, multiReductionOp, reductionAxis, reductionIdx + 1, vectorType, + inductionVars, loopState, lastDimReduction, loc, loopStep); + maybeYieldValue(b, loc, nxtFor->getResults()); + } + }); - auto forSteps = opBuilder.create( - loc, opBuilder.getIndexType(), - opBuilder.getIntegerAttr( - opBuilder.getIndexType(), - reductionIdx == reductionAxis.size() - 1 ? 
loop_step : 1)); - auto numIter = opBuilder.create( - loc, opBuilder.getIndexType(), - opBuilder.getIntegerAttr(opBuilder.getIndexType(), - vectorType.getShape()[reductionIdx])); - forOp = opBuilder.create( - loc, zero, numIter, forSteps, iterArgs, - [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) { - inductionVars.emplace_back(iv); - - if (reductionIdx == reductionAxis.size() - 1) { - auto source = multiReductionOp->getOperand(0); - - auto readOp = - mlir::dyn_cast(source.getDefiningOp()); - assert(readOp); - IRMapping indiceMap; - IRRewriter rewriter(b); - auto clonedOp = b.clone(*readOp, indiceMap); - int offset = 1; - auto newReadOp = mlir::dyn_cast(clonedOp); - - for (auto [idx, inductionVar] : enumerate(inductionVars)) { - newReadOp->setOperand(idx + offset, inductionVar); - } + return forOp; +} - auto newOperandType = lastDimReduction ? getVectorzedType(newReadOp) - : getScalarType(newReadOp); - newReadOp->getResult(0).setType(newOperandType); - setOpVectorizationPermutationMap( - newReadOp, b, - newReadOp.getSource().getType().dyn_cast(), - newReadOp.getPermutationMap()); - rewriter.replaceOp(readOp, newReadOp); - if (lastDimReduction) { - Operation *reductionOp = rewriter.create( - loc, multiReductionOp.getKind(), newReadOp->getResult(0), - iterArgs.back()); - maybeYieldValue(b, loc, reductionOp->getResults()); - } else { - auto reductionResult = - makeArithReduction(b, loc, multiReductionOp.getKind(), - newReadOp->getResult(0), iterArgs.back()); - maybeYieldValue(b, loc, reductionResult); +scf::ForOp parallelAxisGenerateForLoop( + OpBuilder &opBuilder, vector::MultiDimReductionOp &multiReductionOp, + const llvm::SmallVector ¶llelAxis, const size_t parallelIdx, + const llvm::SmallVector &reductionAxis, + const size_t reductionIdx, const VectorType &vectorType, + llvm::SmallVector &inductionVars, const ValueRange &iterArgs, + Value &originalWriteResult, bool lastDimReduction, Location &loc, + const int loopStep) { + auto zero = opBuilder.create( + loc, opBuilder.getIndexType(), + opBuilder.getIntegerAttr(opBuilder.getIndexType(), 0)); + auto forSteps = opBuilder.create( + loc, opBuilder.getIndexType(), + opBuilder.getIntegerAttr(opBuilder.getIndexType(), 1)); + auto numIter = opBuilder.create( + loc, opBuilder.getIndexType(), + opBuilder.getIntegerAttr( + opBuilder.getIndexType(), + vectorType.getShape()[parallelAxis[parallelIdx]])); + // Create a loop and move vectorized operation into loops. 
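Both reductionAxisGenerateForLoop above and the parallel-axis builder that follows share the same skeleton: an scf.for that carries its running value through iter_args and whose body builder yields the updated value. A stripped-down, single-loop sketch of that pattern is given below; the bounds, step and initial accumulator are assumed to be materialized already, and the per-iteration read and combine are elided.

    // Minimal sketch: one loop level carrying an accumulator as an iter_arg.
    scf::ForOp buildAccumulatingLoop(OpBuilder &b, Location loc, Value lb,
                                     Value ub, Value step, Value initAcc) {
      ValueRange iterArgs(initAcc);
      return b.create<scf::ForOp>(
          loc, lb, ub, step, iterArgs,
          [&](OpBuilder &nested, Location nestedLoc, Value iv,
              ValueRange loopState) {
            // Read this iteration's element and combine it with
            // loopState.back() here (elided), then yield the new accumulator.
            Value updatedAcc = loopState.back();
            nested.create<scf::YieldOp>(nestedLoc, updatedAcc);
          });
    }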
+ auto forOp = opBuilder.create( + loc, zero, numIter, forSteps, iterArgs, + [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) { + inductionVars.emplace_back(iv); + if (parallelIdx == parallelAxis.size() - 1) { + + // read operation + IRMapping accReadMap; + auto multiReductionAcc = multiReductionOp.getAcc(); + auto accReadOp = multiReductionAcc.getDefiningOp(); + accReadMap.map(accReadOp->getOperand(0), loopState.back()); + + auto newAccReadOp = makeNewTransferReadOp( + multiReductionAcc, b, accReadMap, parallelAxis, inductionVars, + lastDimReduction, MultiReduceOpAxisKind::Parallel); + // constructe next for loop + auto accVal = b.create( + loc, opBuilder.getZeroAttr(vectorType.getElementType())); + ValueRange newIterArgs(accVal); + auto nxtFor = reductionAxisGenerateForLoop( + b, multiReductionOp, reductionAxis, reductionIdx, vectorType, + inductionVars, newIterArgs, lastDimReduction, loc, loopStep); + + // insert accumulate value to original vector + auto accRes = nxtFor->getResults()[0]; + auto insertOp = b.create( + loc, accRes, newAccReadOp->getResult(0), 0); + + // write vector back to tensor + vector::TransferWriteOp accWriteOp = nullptr; + for (auto [idx, x] : + llvm::enumerate(multiReductionOp->getResults()[0].getUsers())) { + if (idx == 0 && mlir::isa(x)) { + accWriteOp = mlir::dyn_cast(x); + break; } - } else { - // outter loop - auto nxtFor = generateMultiReductionForLoop( - b, multiReductionOp, parallelAxis, parallelIdx, reductionAxis, - reductionIdx + 1, vectorType, inductionVars, iterArgs, - originalWriteResult, lastDimReduction); - maybeYieldValue(b, loc, nxtFor.value()->getResults()); } - }); - } + assert(accWriteOp && + " Not transfer_write operation. Current multireduction " + "operation may have wrong analysis IR."); + IRMapping accWriteindiceMap; + accWriteindiceMap.map(accWriteOp.getOperand(0), + insertOp->getResults()[0]); + auto writeResult = accWriteOp->getResults()[0]; + auto newAccWriteOp = makeNewTransferWriteOp( + writeResult, accWriteindiceMap, b, parallelAxis, inductionVars); + originalWriteResult = newAccWriteOp->getResult(0); + + maybeYieldValue(b, loc, newAccWriteOp->getResults()); + } else { + auto nxtFor = parallelAxisGenerateForLoop( + b, multiReductionOp, parallelAxis, parallelIdx + 1, reductionAxis, + reductionIdx, vectorType, inductionVars, loopState, + originalWriteResult, lastDimReduction, loc, loopStep); + maybeYieldValue(b, loc, nxtFor->getResults()); + } + }); return forOp; } -LogicalResult -canonicalizeReductionOperation(vector::MultiDimReductionOp &multiReductionOp, - IRRewriter &rewriter) { - OpBuilder::InsertionGuard guard(rewriter); - - auto srcVecType = multiReductionOp.getSourceVectorType(); - auto srcRank = multiReductionOp.getSourceVectorType().getRank(); - - // Separate reduction and parallel dims - bool lastDimReduction = false; - auto reductionAxisRange = - multiReductionOp.getReductionDims().getAsValueRange(); - auto reductionRange = llvm::to_vector<4>(llvm::map_range( - reductionAxisRange, [](const APInt &a) { return a.getZExtValue(); })); - llvm::SmallVector reductionAxis(reductionRange.begin(), - reductionRange.end()); - llvm::SmallDenseSet reductionAxisSet(reductionAxis.begin(), - reductionAxis.end()); - if (reductionAxisSet.contains(srcRank - 1)) { - lastDimReduction = true; - } - SmallVector parallelAxis; - for (int64_t i = 0; i < srcRank; ++i) { - if (!reductionAxisSet.contains(i)) { - parallelAxis.push_back(i); - } - } - /* - * The final IR may look like below: - * _for_(_fuseiter_i, 0, 1) - * sum = 0; - * 
_for_(_fuseiter_j, 0, 1) - * _for_(_fuseiter_k, 0, 1) - * sum += src[src_idx]; - * dst[dst_idx] = sum; - * */ - Operation *newReduction; - Value multiReductionAcc = multiReductionOp.getAcc(); - auto accTensorReadOp = - multiReductionAcc.getDefiningOp(); - Value originalWriteResult; - ValueRange iterArgs(accTensorReadOp->getOperand(0)); - llvm::SmallVector inductionVars; - auto forOp = generateMultiReductionForLoop( - rewriter, multiReductionOp, parallelAxis, 0, reductionAxis, 0, srcVecType, - inductionVars, iterArgs, originalWriteResult, lastDimReduction); - if (failed(forOp)) { - LDBG("MultiReduction Operation lowering failed"); - return failure(); - } - newReduction = forOp.value(); - rewriter.replaceAllUsesWith(originalWriteResult, - newReduction->getResults()[0]); - forOp->dump(); +scf::ForOp generateMultiReductionForLoop( + OpBuilder &opBuilder, vector::MultiDimReductionOp &multiReductionOp, + const llvm::SmallVector ¶llelAxis, const size_t parallelIdx, + const llvm::SmallVector &reductionAxis, + const size_t reductionIdx, const VectorType &vectorType, + llvm::SmallVector &inductionVars, const ValueRange &iterArgs, + Value &originalWriteResult, bool lastDimReduction) { + const int loopStep = getDataTypeMAXSIMDLength(vectorType); + auto loc = multiReductionOp->getLoc(); - rewriter.replaceOp(multiReductionOp, newReduction); - return success(); + scf::ForOp forOp = parallelAxisGenerateForLoop( + opBuilder, multiReductionOp, parallelAxis, parallelIdx, reductionAxis, + reductionIdx, vectorType, inductionVars, iterArgs, originalWriteResult, + lastDimReduction, loc, loopStep); + return forOp; } +// mlir::FailureOr generateTransposeForLoop( +// OpBuilder &opBuilder, vector::TransposeOp &transposeOp, +// const llvm::SmallVector ¶llelAxis, const size_t +// parallelIdx, const llvm::SmallVector &reductionAxis, const +// size_t reductionIdx, const VectorType &vectorType, +// llvm::SmallVector &inductionVars, const ValueRange &iterArgs, +// Value &originalWriteResult, bool lastDimReduction) { +// const int loop_step = getDataTypeMAXSIMDLength(vectorType); +// auto loc = transposeOp->getLoc(); +// auto zero = opBuilder.create( +// loc, opBuilder.getIndexType(), +// opBuilder.getIntegerAttr(opBuilder.getIndexType(), 0)); + +// scf::ForOp forOp = nullptr; +// // parallel axis +// if (parallelIdx < parallelAxis.size()) { +// auto forSteps = opBuilder.create( +// loc, opBuilder.getIndexType(), +// opBuilder.getIntegerAttr(opBuilder.getIndexType(), 1)); +// auto numIter = opBuilder.create( +// loc, opBuilder.getIndexType(), +// opBuilder.getIntegerAttr( +// opBuilder.getIndexType(), +// vectorType.getShape()[parallelAxis[parallelIdx]])); +// // Create a loop and move vectorized operation into loops. +// forOp = opBuilder.create( +// loc, zero, numIter, forSteps, iterArgs, +// [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) { +// inductionVars.emplace_back(iv); +// if (parallelIdx == parallelAxis.size() - 1) { +// // move original transfer_read operation into parallel axis loop +// // body +// // get read operation +// Value multiReductionAcc = multiReductionOp.getAcc(); +// auto accReadOp = +// multiReductionAcc.getDefiningOp(); +// assert(accReadOp && +// " Not transfer_read operation. 
Current multireduction " +// "operation may have wrong analysis IR."); +// // get write operation +// vector::TransferWriteOp accWriteOp = nullptr; +// for (auto [idx, x] : llvm::enumerate( +// multiReductionOp->getResults()[0].getUsers())) { +// if (idx == 0 && mlir::isa(x)) { +// accWriteOp = mlir::dyn_cast(x); +// break; +// } +// } +// assert(accWriteOp); +// IRMapping accReadindiceMap; + +// IRRewriter bodyRewriter(b); +// auto newAccReadOp = mlir::dyn_cast( +// b.clone(*accReadOp, accReadindiceMap)); +// bodyRewriter.replaceOp(accReadOp, newAccReadOp); +// int offset = 1; +// for (auto [idx, inductionVar] : llvm::enumerate(inductionVars)) { +// if (idx >= parallelAxis.size()) { +// break; +// } +// newAccReadOp->setOperand(idx + offset, inductionVar); +// } +// auto newOperandType = getScalarType(newAccReadOp); +// newAccReadOp->getResult(0).setType(newOperandType); +// setOpVectorizationPermutationMap( +// newAccReadOp, b, +// newAccReadOp.getSource().getType().dyn_cast(), +// newAccReadOp.getPermutationMap()); +// // constructe next for loop +// auto accVal = b.create( +// loc, opBuilder.getZeroAttr(vectorType.getElementType())); +// ValueRange newIterArgs(accVal); +// auto nxtFor = generateMultiReductionForLoop( +// b, multiReductionOp, parallelAxis, parallelIdx + 1, +// reductionAxis, reductionIdx, vectorType, inductionVars, +// newIterArgs, originalWriteResult, lastDimReduction); + +// // move original transfer_write into loop +// auto accRes = nxtFor.value()->getResults()[0]; + +// // replace the vector as the loop return vector value +// llvm::SmallVector insertPos; +// auto insertOp = b.create( +// loc, accRes, newAccReadOp->getResult(0), 0); +// IRMapping accWriteindiceMap; +// accWriteindiceMap.map(accWriteOp.getOperand(0), +// insertOp->getResults()[0]); +// auto newAccWriteOp = mlir::dyn_cast( +// b.clone(*accWriteOp, accWriteindiceMap)); +// offset = 2; +// for (auto [idx, inductionVar] : llvm::enumerate(inductionVars)) { +// if (idx >= parallelAxis.size()) { +// break; +// } +// newAccWriteOp->setOperand(idx + offset, inductionVar); +// } +// setOpVectorizationPermutationMap(newAccWriteOp, b, +// newAccWriteOp->getResult(0) +// .getType() +// .dyn_cast(), +// newAccWriteOp.getPermutationMap()); +// bodyRewriter.replaceOp(accWriteOp, newAccWriteOp); +// originalWriteResult = newAccWriteOp->getResult(0); +// maybeYieldValue(b, loc, newAccWriteOp->getResults()); +// } else { +// auto nxtFor = generateMultiReductionForLoop( +// b, multiReductionOp, parallelAxis, parallelIdx + 1, +// reductionAxis, reductionIdx, vectorType, inductionVars, +// iterArgs, originalWriteResult, lastDimReduction); +// maybeYieldValue(b, loc, nxtFor.value()->getResults()); +// } +// }); + +// } else { + +// auto forSteps = opBuilder.create( +// loc, opBuilder.getIndexType(), +// opBuilder.getIntegerAttr( +// opBuilder.getIndexType(), +// reductionIdx == reductionAxis.size() - 1 ? 
loop_step : 1)); +// auto numIter = opBuilder.create( +// loc, opBuilder.getIndexType(), +// opBuilder.getIntegerAttr(opBuilder.getIndexType(), +// vectorType.getShape()[reductionIdx])); +// forOp = opBuilder.create( +// loc, zero, numIter, forSteps, iterArgs, +// [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) { +// inductionVars.emplace_back(iv); + +// if (reductionIdx == reductionAxis.size() - 1) { +// auto source = multiReductionOp->getOperand(0); + +// auto readOp = +// mlir::dyn_cast(source.getDefiningOp()); +// assert(readOp); +// IRMapping indiceMap; +// IRRewriter rewriter(b); +// auto clonedOp = b.clone(*readOp, indiceMap); +// int offset = 1; +// auto newReadOp = +// mlir::dyn_cast(clonedOp); + +// for (auto [idx, inductionVar] : enumerate(inductionVars)) { +// newReadOp->setOperand(idx + offset, inductionVar); +// } + +// auto newOperandType = lastDimReduction ? +// getVectorzedType(newReadOp) +// : +// getScalarType(newReadOp); +// newReadOp->getResult(0).setType(newOperandType); +// setOpVectorizationPermutationMap( +// newReadOp, b, +// newReadOp.getSource().getType().dyn_cast(), +// newReadOp.getPermutationMap()); +// rewriter.replaceOp(readOp, newReadOp); +// if (lastDimReduction) { +// Operation *reductionOp = rewriter.create( +// loc, multiReductionOp.getKind(), newReadOp->getResult(0), +// loopState.back()); +// maybeYieldValue(b, loc, reductionOp->getResults()); +// } else { +// auto reductionResult = +// makeArithReduction(b, loc, multiReductionOp.getKind(), +// newReadOp->getResult(0), +// loopState.back()); +// maybeYieldValue(b, loc, reductionResult); +// } +// } else { +// // outter loop +// auto nxtFor = generateMultiReductionForLoop( +// b, multiReductionOp, parallelAxis, parallelIdx, +// reductionAxis, reductionIdx + 1, vectorType, inductionVars, +// iterArgs, originalWriteResult, lastDimReduction); +// maybeYieldValue(b, loc, nxtFor.value()->getResults()); +// } +// }); +// } +// return forOp; +// } + // 1. Classify operaions: // classify the operations into : // a. reorder, transpose. 
Reorder(or transpose) dim may bring data @@ -501,20 +680,159 @@ class CanonicalizerVectorOperation { llvm::DenseMap &opGroupIndexMap); void canonicalizeSpecialOperation(); + LogicalResult + canonicalizeReductionOperation(vector::MultiDimReductionOp &multiReductionOp, + IRRewriter &rewriter); + LogicalResult canonicalizeTransposeOperation(vector::TransposeOp &transposeOp, + IRRewriter &rewriter); void run(); + +private: + llvm::SetVector multiReductionOps; + llvm::SetVector shapeCastOps; }; +// LogicalResult CanonicalizerVectorOperation::canonicalizeTransposeOperation( +// vector::TransposeOp &transposeOp, IRRewriter &rewriter) { +// OpBuilder::InsertionGuard guard(rewriter); + +// auto srcVecType = multiReductionOp.getSourceVectorType(); +// auto srcRank = multiReductionOp.getSourceVectorType().getRank(); + +// // Separate reduction and parallel dims +// bool lastDimReduction = false; +// auto reductionAxisRange = +// multiReductionOp.getReductionDims().getAsValueRange(); +// auto reductionRange = llvm::to_vector<4>(llvm::map_range( +// reductionAxisRange, [](const APInt &a) { return a.getZExtValue(); })); +// llvm::SmallVector reductionAxis(reductionRange.begin(), +// reductionRange.end()); +// llvm::SmallDenseSet reductionAxisSet(reductionAxis.begin(), +// reductionAxis.end()); +// if (reductionAxisSet.contains(srcRank - 1)) { +// lastDimReduction = true; +// } +// SmallVector parallelAxis; +// for (int64_t i = 0; i < srcRank; ++i) { +// if (!reductionAxisSet.contains(i)) { +// parallelAxis.push_back(i); +// } +// } +// /* +// * The final IR may look like below: +// * _for_(_fuseiter_i, 0, 1) +// * sum = 0; +// * _for_(_fuseiter_j, 0, 1) +// * _for_(_fuseiter_k, 0, 1) +// * sum += src[src_idx]; +// * dst[dst_idx] = sum; +// * */ +// Operation *newReduction; +// Value multiReductionAcc = multiReductionOp.getAcc(); +// auto accTensorReadOp = +// multiReductionAcc.getDefiningOp(); +// Value originalWriteResult; +// ValueRange iterArgs(accTensorReadOp->getOperand(0)); +// llvm::SmallVector inductionVars; +// auto forOp = generateMultiReductionForLoop( +// rewriter, multiReductionOp, parallelAxis, 0, reductionAxis, 0, +// srcVecType, inductionVars, iterArgs, originalWriteResult, +// lastDimReduction); +// if (failed(forOp)) { +// LDBG("MultiReduction Operation lowering failed"); +// return failure(); +// } +// auto replaceIfFn = [&](OpOperand &use) { +// return use.getOwner()->getBlock() != +// originalWriteResult.getDefiningOp()->getBlock(); +// }; +// newReduction = forOp.value(); +// rewriter.replaceOpUsesWithIf(originalWriteResult.getDefiningOp(), +// newReduction->getResults()[0], replaceIfFn); + +// rewriter.replaceOp(multiReductionOp, newReduction); +// return success(); +// } + +LogicalResult CanonicalizerVectorOperation::canonicalizeReductionOperation( + vector::MultiDimReductionOp &multiReductionOp, IRRewriter &rewriter) { + OpBuilder::InsertionGuard guard(rewriter); + + auto srcVecType = multiReductionOp.getSourceVectorType(); + auto srcRank = multiReductionOp.getSourceVectorType().getRank(); + + // Separate reduction and parallel dims + bool lastDimReduction = false; + auto reductionAxisRange = + multiReductionOp.getReductionDims().getAsValueRange(); + auto reductionRange = llvm::to_vector<4>(llvm::map_range( + reductionAxisRange, [](const APInt &a) { return a.getZExtValue(); })); + llvm::SmallVector reductionAxis(reductionRange.begin(), + reductionRange.end()); + llvm::SmallDenseSet reductionAxisSet(reductionAxis.begin(), + reductionAxis.end()); + if 
(reductionAxisSet.contains(srcRank - 1)) { + lastDimReduction = true; + } + SmallVector parallelAxis; + for (int64_t i = 0; i < srcRank; ++i) { + if (!reductionAxisSet.contains(i)) { + parallelAxis.push_back(i); + } + } + /* + * The final IR may look like below: + * _for_(_fuseiter_i, 0, 1) + * sum = 0; + * _for_(_fuseiter_j, 0, 1) + * _for_(_fuseiter_k, 0, 1) + * sum += src[src_idx]; + * dst[dst_idx] = sum; + * */ + Operation *newReduction; + Value multiReductionAcc = multiReductionOp.getAcc(); + auto accTensorReadOp = + multiReductionAcc.getDefiningOp(); + Value originalWriteResult; + ValueRange iterArgs(accTensorReadOp->getOperand(0)); + llvm::SmallVector inductionVars; + auto forOp = generateMultiReductionForLoop( + rewriter, multiReductionOp, parallelAxis, 0, reductionAxis, 0, srcVecType, + inductionVars, iterArgs, originalWriteResult, lastDimReduction); + auto replaceIfFn = [&](OpOperand &use) { + return use.getOwner()->getBlock() != + originalWriteResult.getDefiningOp()->getBlock(); + }; + newReduction = forOp; + rewriter.replaceOpUsesWithIf(originalWriteResult.getDefiningOp(), + newReduction->getResults()[0], replaceIfFn); + + rewriter.replaceOp(multiReductionOp, newReduction); + return success(); +} + void CanonicalizerVectorOperation::canonicalizeSpecialOperation() { - llvm::SetVector reductionOps; - func->walk( - [&](vector::MultiDimReductionOp multiReductionOp) { - reductionOps.insert(multiReductionOp); - }); - for (auto x : reductionOps) { + func->walk([&](Operation *op) { + llvm::TypeSwitch(op) + .Case( + [&](vector::MultiDimReductionOp multiReductionOp) { + multiReductionOps.insert(multiReductionOp); + }) + .Case([&](vector::ShapeCastOp shapeCastOp) { + shapeCastOps.insert(shapeCastOp); + }) + .Default([&](Operation *) {}); + }); + // process reduction + for (auto x : multiReductionOps) { IRRewriter rewriter(x); (void)canonicalizeReductionOperation(x, rewriter); } + // process shapecast + // for (auto x : shapeCastOps) { + // } + return; } void CanonicalizerVectorOperation::run() { @@ -561,9 +879,10 @@ void CanonicalizerVectorOperation::run() { fusionStrategy.getOpGroupIndexMap()); // Speical Operation Canonicalization - // canonicalizeSpecialOperation(); + canonicalizeSpecialOperation(); // 2.Generate vectorized IR for each operation group for (auto [idx, grp] : llvm::enumerate(fusionStrategy.getOpGroups())) { + generateGroupOpVectorizedIR(idx, grp, fusionStrategy.getOpGroupIndexMap()); } @@ -819,10 +1138,12 @@ Value setOutGroupOperationOperandResult(Operation *op, Type resultElementType = newOperandType.getElementType(); Attribute initValueAttr; - if (isa(resultElementType)) + if (isa(resultElementType)) { initValueAttr = FloatAttr::get(resultElementType, 0.0); - else + + } else { initValueAttr = IntegerAttr::get(resultElementType, 0); + } auto cntOp = rewriter.create( rewriter.getUnknownLoc(), DenseElementsAttr::get(newOperandType, {initValueAttr})); @@ -1263,6 +1584,19 @@ void updateLoopResultUses(llvm::SetVector &opResults, } } +bool hasSpecialOperation(std::queue &grp) { + std::queue tmpQ(grp); + while (!tmpQ.empty()) { + auto curOp = tmpQ.front(); + if (mlir::isa(curOp) or + mlir::isa(curOp)) { + return true; + } + tmpQ.pop(); + } + return false; +} + void CanonicalizerVectorOperation::generateGroupOpVectorizedIR( const int idx, std::queue &grp, llvm::DenseMap &opGroupIndexMap) { @@ -1270,6 +1604,10 @@ void CanonicalizerVectorOperation::generateGroupOpVectorizedIR( LDBG("Current operation Group is empty."); return; } + // TODO: special operation better fusion + if 
(hasSpecialOperation(grp)) { + return; + } auto getType = getOperationVectorType(grp.front()); if (failed(getType)) { LDBG("Failed to get vector type for operation: " << *grp.front() << "\n"); @@ -1302,7 +1640,7 @@ struct CPUPhysicalRegisterPass auto *ctx = &getContext(); RewritePatternSet patterns(ctx); auto func = getOperation(); - // canonicalize vector operation + // canonicalize vector operation, default use vector-based fusion strategy. CanonicalizerVectorOperation canonicalizer(func); canonicalizer.run(); } diff --git a/lib/gc/Transforms/LowerTileVectorPass.cpp b/lib/gc/Transforms/LowerTileVectorPass.cpp index 89b8e5354..f60e94e75 100644 --- a/lib/gc/Transforms/LowerTileVectorPass.cpp +++ b/lib/gc/Transforms/LowerTileVectorPass.cpp @@ -479,10 +479,14 @@ struct LowerTileVectorPass tensor::populateFoldTensorSubsetOpPatterns(patterns); tensor::populateFoldTensorEmptyPatterns(patterns, true); tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns); - vector::VectorTransformsOptions vectorTransformOptions; + // vector::VectorTransformsOptions vectorTransformOptions; // vector::populateVectorMultiReductionLoweringPatterns( // patterns, vectorTransformOptions.vectorMultiReductionLowering); // vector::populateVectorShapeCastLoweringPatterns(patterns); + // vector::VectorTransformsOptions options; + // options.vectorTransposeLowering = + // vector::VectorTransposeLowering::Shuffle16x16; + // vector::populateVectorTransposeLoweringPatterns(patterns, options); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } From 75e55466ccd7d80fa44f95c0e87b4be08190c122 Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Wed, 5 Jun 2024 18:47:31 +0800 Subject: [PATCH 07/66] update reduce --- lib/gc/Transforms/CPUPhysicalResigterPass.cpp | 286 +++++++++++++----- 1 file changed, 213 insertions(+), 73 deletions(-) diff --git a/lib/gc/Transforms/CPUPhysicalResigterPass.cpp b/lib/gc/Transforms/CPUPhysicalResigterPass.cpp index 7d9890cd6..82e904345 100644 --- a/lib/gc/Transforms/CPUPhysicalResigterPass.cpp +++ b/lib/gc/Transforms/CPUPhysicalResigterPass.cpp @@ -183,6 +183,63 @@ Type getScalarType(Operation *op) { return VectorType::get({1}, vectorizedType.getElementType()); } +Operation *createTensorEmptyBefore(Operation *op) { + auto rtType = op->getResultTypes()[0].dyn_cast(); + IRRewriter reWriter(op); + + SmallVector shapes; + SmallVector dynDims; + for (unsigned i = 0; i < rtType.getRank(); i++) { + shapes.push_back(rtType.getDimSize(i)); + if (rtType.isDynamicDim(i)) + dynDims.push_back(reWriter.create(reWriter.getUnknownLoc(), + op->getResult(0), i)); + } + return reWriter.create(op->getLoc(), rtType.getShape(), + rtType.getElementType(), dynDims); +} + +Operation * +createTransferReadOpBefore(Operation *op, const Value &operand, + vector::TransferReadOp *srcReadOp = nullptr) { + auto operandType = operand.getType().dyn_cast(); + + IRRewriter rewriter(op); + auto zero = + rewriter.create(rewriter.getUnknownLoc(), 0); + auto padValue = rewriter.create( + rewriter.getUnknownLoc(), + rewriter.getZeroAttr(operandType.getElementType())); + + if (srcReadOp) { + auto resultType = srcReadOp->getType().dyn_cast(); + SmallVector inBoundsVal(resultType.getRank(), true); + auto srcReadOpAffineMap = srcReadOp->getPermutationMap(); + // result of read operation should be same as operand + auto t = rewriter.create( + op->getLoc(), + /*vectorType=*/ + VectorType::get(resultType.getShape(), resultType.getElementType()), + /*source=*/operand, + 
/*indices=*/SmallVector(operandType.getRank(), zero), + /**affinemap*/ srcReadOpAffineMap, + /*inBounds=*/inBoundsVal); + + return t; + } else { + SmallVector inBoundsVal(operandType.getRank(), true); + auto t = rewriter.create( + op->getLoc(), + /*vectorType=*/ + VectorType::get(operandType.getShape(), operandType.getElementType()), + /*source=*/operand, + /*indices=*/SmallVector(operandType.getRank(), zero), + /**affinemap*/ padValue, + /*inBounds=*/inBoundsVal); + return t; + } +} + // __________________________________ // Speical operations canonicalization // __________________________________ @@ -290,17 +347,17 @@ scf::ForOp reductionAxisGenerateForLoop( lastDimReduction, MultiReduceOpAxisKind::Reduction); // reduction or elementwise reduce - if (lastDimReduction) { - Operation *reductionOp = rewriter.create( - loc, multiReductionOp.getKind(), newReadOp->getResult(0), - loopState.back()); - maybeYieldValue(b, loc, reductionOp->getResults()); - } else { - auto reductionResult = - makeArithReduction(b, loc, multiReductionOp.getKind(), - newReadOp->getResult(0), loopState.back()); - maybeYieldValue(b, loc, reductionResult); - } + // if (lastDimReduction) { + // Operation *reductionOp = rewriter.create( + // loc, multiReductionOp.getKind(), newReadOp->getResult(0), + // loopState.back()); + // maybeYieldValue(b, loc, reductionOp->getResults()); + // } else { + auto reductionResult = + makeArithReduction(b, loc, multiReductionOp.getKind(), + newReadOp->getResult(0), loopState.back()); + maybeYieldValue(b, loc, reductionResult); + // } } else { // outter loop auto nxtFor = reductionAxisGenerateForLoop( @@ -348,9 +405,21 @@ scf::ForOp parallelAxisGenerateForLoop( auto newAccReadOp = makeNewTransferReadOp( multiReductionAcc, b, accReadMap, parallelAxis, inductionVars, lastDimReduction, MultiReduceOpAxisKind::Parallel); + auto resultElementType = vectorType.getElementType(); // constructe next for loop + // auto accVal = b.create( + // loc, opBuilder.getZeroAttr(vectorType.getElementType())); + Attribute initValueAttr; + if (isa(resultElementType)) { + initValueAttr = FloatAttr::get(resultElementType, 0.0); + + } else { + initValueAttr = IntegerAttr::get(resultElementType, 0); + } auto accVal = b.create( - loc, opBuilder.getZeroAttr(vectorType.getElementType())); + loc, DenseElementsAttr::get(getVectorzedType(multiReductionOp), + {initValueAttr})); + ValueRange newIterArgs(accVal); auto nxtFor = reductionAxisGenerateForLoop( b, multiReductionOp, reductionAxis, reductionIdx, vectorType, @@ -358,8 +427,11 @@ scf::ForOp parallelAxisGenerateForLoop( // insert accumulate value to original vector auto accRes = nxtFor->getResults()[0]; + + Operation *reductionOp = b.create( + loc, multiReductionOp.getKind(), accRes); auto insertOp = b.create( - loc, accRes, newAccReadOp->getResult(0), 0); + loc, reductionOp->getResult(0), newAccReadOp->getResults()[0], 0); // write vector back to tensor vector::TransferWriteOp accWriteOp = nullptr; @@ -1064,13 +1136,134 @@ bool isCompatibleVectorType(Operation *op1, Operation *op2) { auto sp1 = type1.value(); auto sp2 = type2.value(); auto min_rank = std::min(sp1.getRank(), sp2.getRank()) - 1; - for (auto i = min_rank; i >= 0; i--) { + bool isCompatible = true; + // from front to back + for (long i = 0; i < min_rank; i++) { if (sp1.getDimSize(i) != sp2.getDimSize(i)) { - return false; + isCompatible = false; + break; } } - return true; + return isCompatible; +} + +/// which axis do the shape cast in source shape a +void shapeCastSourceAxis(const ArrayRef &a, const 
ArrayRef &b, + llvm::SmallVector &res) { + unsigned rankA = a.size(); + unsigned rankB = b.size(); + assert(rankA < rankB && "May be invalid shape cast operation."); + + auto isOne = [](int64_t v) { return v == 1; }; + + // Special-case for n-D to 0-d shape cast. 'b' must be all ones to be shape + // casted to a 0-d vector. + if (rankA == 0 && llvm::all_of(b, isOne)) { + for (size_t i = 0; i < a.size(); i++) { + res.emplace_back(i); + } + return; + } + + unsigned i = 0; + unsigned j = 0; + while (i < rankA && j < rankB) { + int64_t dimA = a[i]; + int64_t dimB = 1; + int64_t bAxisBegin = j; + while (dimB < dimA && j < rankB) + dimB *= b[j++]; + if (dimA != dimB) { + assert(false && " Invalid shape cast operation."); + break; + } + if (bAxisBegin != j) { + res.emplace_back(i); + } + ++i; + + // Handle the case when trailing dimensions are of size 1. + // Include them into the contiguous sequence. + if (i < rankA && llvm::all_of(a.slice(i), isOne)) + i = rankA; + if (j < rankB && llvm::all_of(b.slice(j), isOne)) + j = rankB; + } + + assert(i == rankA && j == rankB && "Invalid shapecast operation."); +} + +void getOperationDataAxis(Operation *op, llvm::SmallVector &dataAxis) { + return TypeSwitch(op) + .Case( + [&](vector::MultiDimReductionOp multiReductionOp) { + auto rdDimsRange = multiReductionOp.getReductionDims() + .getAsValueRange(); + auto reductionDims = llvm::to_vector<4>(llvm::map_range( + rdDimsRange, [](const APInt &a) { return a.getZExtValue(); })); + dataAxis.assign(reductionDims.begin(), reductionDims.end()); + }) + .Case([&](vector::ShapeCastOp shapeCastOp) { + auto srcType = shapeCastOp.getSourceVectorType(); + auto dstType = shapeCastOp.getResultVectorType(); + auto srcShape = srcType.getShape(); + auto dstShape = dstType.getShape(); + shapeCastSourceAxis(srcShape, dstShape, dataAxis); + }) + .Default([&](Operation *op) { + // default is last axis + dataAxis.emplace_back( + op->getResultTypes().front().cast().getRank() - 1); + }); +} + +bool hasDataDependency(Operation *op1, Operation *op2) { + // op1 must be special operation + if (!isSpecialOp(op1)) { + return hasDataDependency(op2, op1); + } + auto hasSameAxis = [](const llvm::SmallVector &dims1, + const llvm::SmallVector &dims2) { + llvm::DenseSet checkSet(dims2.begin(), dims2.end()); + for (auto x : dims1) { + if (checkSet.contains(x)) { + return true; + } + } + return false; + }; + auto res = + TypeSwitch(op1) + .Case([&](vector::ShapeCastOp shapeCastOp) { + llvm::SmallVector dims1, dims2; + getOperationDataAxis(op1, dims1); + getOperationDataAxis(op2, dims2); + return hasSameAxis(dims1, dims2); + }) + .Case<>([&](vector::MultiDimReductionOp multiReductionOp) { + // op1 is special operation, op2 is normal operation + // op1 and op2 is both speicial operation + auto rdDimsRange = multiReductionOp.getReductionDims() + .getAsValueRange(); + auto reductionDims = llvm::to_vector( + llvm::map_range(rdDimsRange, [](const APInt &a) { + return (int64_t)a.getZExtValue(); + })); + llvm::SmallVector dims2; + getOperationDataAxis(op2, dims2); + llvm::DenseSet checkSet(dims2.begin(), dims2.end()); + + for (auto x : dims2) { + if (!checkSet.contains(x)) { + return true; + } + } + return false; + }) + .Default([&](Operation *op) { return false; }); + + return res; } bool isNeedNewGroup(llvm::SmallVector, 8> &opGroups, @@ -1084,6 +1277,10 @@ bool isNeedNewGroup(llvm::SmallVector, 8> &opGroups, } // previous operation is a special operation if (isSpecialOp(prevOp)) { + // special operation need to check data dependency axis + if 
(hasDataDependency(prevOp, op)) { + return true; + } return true; } // previous operation vector type is not compatible with current operation @@ -1174,63 +1371,6 @@ void setOperationOperandResult( } }; -Operation *createTensorEmptyBefore(Operation *op) { - auto rtType = op->getResultTypes()[0].dyn_cast(); - IRRewriter reWriter(op); - - SmallVector shapes; - SmallVector dynDims; - for (unsigned i = 0; i < rtType.getRank(); i++) { - shapes.push_back(rtType.getDimSize(i)); - if (rtType.isDynamicDim(i)) - dynDims.push_back(reWriter.create(reWriter.getUnknownLoc(), - op->getResult(0), i)); - } - return reWriter.create(op->getLoc(), rtType.getShape(), - rtType.getElementType(), dynDims); -} - -Operation * -createTransferReadOpBefore(Operation *op, const Value &operand, - vector::TransferReadOp *srcReadOp = nullptr) { - auto operandType = operand.getType().dyn_cast(); - - IRRewriter rewriter(op); - auto zero = - rewriter.create(rewriter.getUnknownLoc(), 0); - auto padValue = rewriter.create( - rewriter.getUnknownLoc(), - rewriter.getZeroAttr(operandType.getElementType())); - - if (srcReadOp) { - auto resultType = srcReadOp->getType().dyn_cast(); - SmallVector inBoundsVal(resultType.getRank(), true); - auto srcReadOpAffineMap = srcReadOp->getPermutationMap(); - // result of read operation should be same as operand - auto t = rewriter.create( - op->getLoc(), - /*vectorType=*/ - VectorType::get(resultType.getShape(), resultType.getElementType()), - /*source=*/operand, - /*indices=*/SmallVector(operandType.getRank(), zero), - /**affinemap*/ srcReadOpAffineMap, - /*inBounds=*/inBoundsVal); - - return t; - } else { - SmallVector inBoundsVal(operandType.getRank(), true); - auto t = rewriter.create( - op->getLoc(), - /*vectorType=*/ - VectorType::get(operandType.getShape(), operandType.getElementType()), - /*source=*/operand, - /*indices=*/SmallVector(operandType.getRank(), zero), - /**affinemap*/ padValue, - /*inBounds=*/inBoundsVal); - return t; - } -} - Operation *createTransferWriteOpAfter(Operation *op, const Value &dest) { auto rtType = op->getResultTypes()[0].dyn_cast(); auto rank = rtType.getRank(); From b80822b345710021fb66e97baa3b3fc5c99bc68e Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Sat, 8 Jun 2024 17:51:50 +0800 Subject: [PATCH 08/66] record --- lib/gc/Transforms/CPUPhysicalResigterPass.cpp | 2 +- .../gc/transforms/cpu-vetor-distribution.mlir | 188 ++++++++++++++++-- 2 files changed, 171 insertions(+), 19 deletions(-) diff --git a/lib/gc/Transforms/CPUPhysicalResigterPass.cpp b/lib/gc/Transforms/CPUPhysicalResigterPass.cpp index 82e904345..4a11fb65b 100644 --- a/lib/gc/Transforms/CPUPhysicalResigterPass.cpp +++ b/lib/gc/Transforms/CPUPhysicalResigterPass.cpp @@ -111,7 +111,7 @@ mlir::FailureOr getOperationVectorType(Operation *op) { [&](arith::ConstantOp constantOp) { return failure(); }) .Default([&](Operation *op) -> mlir::FailureOr { if (!op->getResults().empty()) { - auto t = op->getResultTypes().front().dyn_cast(); + auto t = mlir::dyn_cast(op->getResultTypes().front()); if (t) { return t; } diff --git a/test/gc/transforms/cpu-vetor-distribution.mlir b/test/gc/transforms/cpu-vetor-distribution.mlir index e4e39ec33..32ee8414f 100644 --- a/test/gc/transforms/cpu-vetor-distribution.mlir +++ b/test/gc/transforms/cpu-vetor-distribution.mlir @@ -1,11 +1,12 @@ // RUN: gc-opt --split-input-file --lower-to-tile-vector --CPU-physical-register-pass --mlir-print-ir-after-all -- %s // CHECK-LABEL: func @add_tensor -// func.func @add_tensor_test0(%arg0: tensor<4x8x1024xf32>, %arg1: 
tensor<4x8x1024xf32>) -> tensor<4x8x1024xf32> { -// %0 = tensor.empty() : tensor<4x8x1024xf32> -// %1 = linalg.add ins(%arg0, %arg1 : tensor<4x8x1024xf32>, tensor<4x8x1024xf32>) outs(%0: tensor<4x8x1024xf32>) -> tensor<4x8x1024xf32> -// return %1 : tensor<4x8x1024xf32> -// } +func.func @add_tensor_test0(%arg0: tensor<11008x4096xf32>, %arg1: tensor<11008x4096xf32>) -> tensor<11008x4096xf32> { + %0 = tensor.empty() : tensor<11008x4096xf32> + %1 = linalg.add ins(%arg0, %arg1 : tensor<11008x4096xf32>, tensor<11008x4096xf32>) outs(%0: tensor<11008x4096xf32>) -> tensor<11008x4096xf32> + %2 = linalg.add ins(%1, %arg1 : tensor<11008x4096xf32>, tensor<11008x4096xf32>) outs(%0: tensor<11008x4096xf32>) -> tensor<11008x4096xf32> + return %2 : tensor<11008x4096xf32> +} // func.func @fc_relu(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>, // %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>) @@ -28,17 +29,168 @@ // func.return %relued : tensor<512x512xf32> // } -func.func @reduce_keepdim0(%arg0: tensor<16x32x64xf32>) -> tensor<16x1x64xf32> { - %0 = tensor.empty() : tensor<16x64xf32> - %reduce = linalg.reduce - ins(%arg0:tensor<16x32x64xf32>) - outs(%0:tensor<16x64xf32>) - dimensions = [1] - (%in: f32, %out: f32) { - %1 = arith.addf %out, %in: f32 - linalg.yield %1: f32 - } - %2 = tensor.expand_shape %reduce [[0],[1, 2]] : tensor<16x64xf32> into tensor<16x1x64xf32> - return %2 : tensor<16x1x64xf32> -} +// func.func @reduce_keepdim0(%arg0: tensor<16x32x64xf32>) -> tensor<16x1x64xf32> { +// %0 = tensor.empty() : tensor<16x64xf32> +// %reduce = linalg.reduce +// ins(%arg0:tensor<16x32x64xf32>) +// outs(%0:tensor<16x64xf32>) +// dimensions = [1] +// (%in: f32, %out: f32) { +// %1 = arith.addf %out, %in: f32 +// linalg.yield %1: f32 +// } +// %2 = tensor.expand_shape %reduce [[0],[1, 2]] : tensor<16x64xf32> into tensor<16x1x64xf32> +// return %2 : tensor<16x1x64xf32> +// } + +// func.func @insert_pad_into_fill(%input: tensor, %low0: index, %low1: index, %high1: index, %high2: index) -> tensor<8x384x384xf32> { +// %f0 = arith.constant 0.0 : f32 +// %c0 = arith.constant 0 : index +// %pad = tensor.pad %input low[%low0, %low1, %c0] high[%c0, %high1, %high2] { +// ^bb0(%arg3: index, %arg4: index, %arg5: index): +// tensor.yield %f0 : f32 +// } : tensor to tensor<8x128x128xf32> +// %empty = tensor.empty() : tensor<8x384x384xf32> +// %fill = linalg.fill ins(%f0 : f32) outs(%empty : tensor<8x384x384xf32>) -> tensor<8x384x384xf32> +// %0 = tensor.insert_slice %pad into %fill[0, 1, 2] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32> +// return %0: tensor<8x384x384xf32> +// } + +// #map = affine_map<(d0) -> (d0 * 64)> +// #map1 = affine_map<(d0) -> (d0 * 128)> +// #map2 = affine_map<(d0) -> (d0 floordiv 16)> +// #map3 = affine_map<(d0) -> (d0 floordiv 32)> +// #map4 = affine_map<(d0, d1, d2) -> (d0 + d1 + d2 * 128)> +// #map5 = affine_map<(d0, d1, d2) -> (d0 + d1 + d2 * 64)> +// module { +// func.func @mlp(%arg0: tensor<128x512xbf16>, %arg1: tensor<32x8x16x32xbf16>, %arg2: tensor<256xbf16>) -> tensor<128x256xbf16> { +// %c32 = arith.constant 32 : index +// %c512 = arith.constant 512 : index +// %c128 = arith.constant 128 : index +// %c64 = arith.constant 64 : index +// %c0 = arith.constant 0 : index +// %cst = arith.constant 0.000000e+00 : bf16 +// %0 = tensor.empty() : tensor<128x256xbf16> +// %1 = tensor.empty() : tensor<512x256xbf16> +// %2:3 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %0, %arg6 = %0, %arg7 = %0) -> (tensor<128x256xbf16>, tensor<128x256xbf16>, 
tensor<128x256xbf16>) { +// %3 = affine.apply #map(%arg3) +// %4 = affine.apply #map1(%arg4) +// %extracted_slice = tensor.extract_slice %arg0[%3, 0] [64, 512] [1, 1] : tensor<128x512xbf16> to tensor<64x512xbf16> +// %extracted_slice_0 = tensor.extract_slice %arg5[%3, %4] [64, 128] [1, 1] : tensor<128x256xbf16> to tensor<64x128xbf16> +// %5:3 = scf.for %arg8 = %c0 to %c64 step %c64 iter_args(%arg9 = %extracted_slice_0, %arg10 = %extracted_slice_0, %arg11 = %extracted_slice_0) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { +// %6:3 = scf.for %arg12 = %c0 to %c128 step %c128 iter_args(%arg13 = %arg9, %arg14 = %arg10, %arg15 = %arg11) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { +// %7:3 = scf.for %arg16 = %c0 to %c512 step %c512 iter_args(%arg17 = %arg13, %arg18 = %arg14, %arg19 = %arg15) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { +// %extracted_slice_1 = tensor.extract_slice %extracted_slice[%arg8, %arg16] [64, 512] [1, 1] : tensor<64x512xbf16> to tensor<64x512xbf16> +// %extracted_slice_2 = tensor.extract_slice %arg17[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> to tensor<64x128xbf16> +// %8:3 = scf.for %arg20 = %c0 to %c64 step %c32 iter_args(%arg21 = %extracted_slice_2, %arg22 = %arg18, %arg23 = %arg19) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { +// %9:3 = scf.for %arg24 = %c0 to %c128 step %c32 iter_args(%arg25 = %arg21, %arg26 = %arg22, %arg27 = %arg23) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { +// %10:3 = scf.for %arg28 = %c0 to %c512 step %c512 iter_args(%arg29 = %arg25, %arg30 = %arg26, %arg31 = %arg27) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { +// %extracted_slice_5 = tensor.extract_slice %extracted_slice_1[%arg20, %arg28] [32, 512] [1, 1] : tensor<64x512xbf16> to tensor<32x512xbf16> +// %11 = affine.apply #map2(%arg28) +// %12 = affine.apply #map3(%arg24) +// %extracted_slice_6 = tensor.extract_slice %arg1[%11, %12, 0, 0] [32, 1, 16, 32] [1, 1, 1, 1] : tensor<32x8x16x32xbf16> to tensor<32x1x16x32xbf16> +// %extracted_slice_7 = tensor.extract_slice %1[%arg28, %arg24] [512, 32] [1, 1] : tensor<512x256xbf16> to tensor<512x32xbf16> +// %unpack = tensor.unpack %extracted_slice_6 inner_dims_pos = [0, 1] inner_tiles = [16, 32] into %extracted_slice_7 : tensor<32x1x16x32xbf16> -> tensor<512x32xbf16> +// %extracted_slice_8 = tensor.extract_slice %arg29[%arg20, %arg24] [32, 32] [1, 1] : tensor<64x128xbf16> to tensor<32x32xbf16> +// %13 = linalg.fill ins(%cst : bf16) outs(%extracted_slice_8 : tensor<32x32xbf16>) -> tensor<32x32xbf16> +// %expanded = tensor.expand_shape %extracted_slice_5 [[0, 1], [2]] output_shape [1, 32, 512] : tensor<32x512xbf16> into tensor<1x32x512xbf16> +// %expanded_9 = tensor.expand_shape %unpack [[0, 1], [2]] output_shape [1, 32, 512] : tensor<512x32xbf16> into tensor<1x512x32xbf16> +// %14 = linalg.batch_reduce_matmul ins(%expanded, %expanded_9 : tensor<1x32x512xbf16>, tensor<1x512x32xbf16>) outs(%13 : tensor<32x32xbf16>) -> tensor<32x32xbf16> +// %15 = affine.apply #map4(%arg12, %arg24, %arg4) +// %16 = affine.apply #map5(%arg8, %arg20, %arg3) +// %extracted_slice_10 = tensor.extract_slice %arg2[%15] [32] [1] : tensor<256xbf16> to tensor<32xbf16> +// %extracted_slice_11 = tensor.extract_slice %0[%16, %15] [32, 32] [1, 1] : tensor<128x256xbf16> to tensor<32x32xbf16> +// %broadcasted = linalg.broadcast ins(%extracted_slice_10 : tensor<32xbf16>) outs(%extracted_slice_11 : tensor<32x32xbf16>) dimensions = [0] 
+// %extracted_slice_12 = tensor.extract_slice %arg30[%arg20, %arg24] [32, 32] [1, 1] : tensor<64x128xbf16> to tensor<32x32xbf16> +// %17 = linalg.add ins(%14, %broadcasted : tensor<32x32xbf16>, tensor<32x32xbf16>) outs(%extracted_slice_12 : tensor<32x32xbf16>) -> tensor<32x32xbf16> +// %inserted_slice_13 = tensor.insert_slice %14 into %arg29[%arg20, %arg24] [32, 32] [1, 1] : tensor<32x32xbf16> into tensor<64x128xbf16> +// %extracted_slice_14 = tensor.extract_slice %arg31[%arg20, %arg24] [32, 32] [1, 1] : tensor<64x128xbf16> to tensor<32x32xbf16> +// %18 = linalg.exp ins(%17 : tensor<32x32xbf16>) outs(%extracted_slice_14 : tensor<32x32xbf16>) -> tensor<32x32xbf16> +// %inserted_slice_15 = tensor.insert_slice %17 into %arg30[%arg20, %arg24] [32, 32] [1, 1] : tensor<32x32xbf16> into tensor<64x128xbf16> +// %inserted_slice_16 = tensor.insert_slice %18 into %arg31[%arg20, %arg24] [32, 32] [1, 1] : tensor<32x32xbf16> into tensor<64x128xbf16> +// scf.yield %inserted_slice_13, %inserted_slice_15, %inserted_slice_16 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> +// } +// scf.yield %10#0, %10#1, %10#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> +// } +// scf.yield %9#0, %9#1, %9#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> +// } +// %inserted_slice = tensor.insert_slice %8#0 into %arg17[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<64x128xbf16> +// %inserted_slice_3 = tensor.insert_slice %8#1 into %arg18[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<64x128xbf16> +// %inserted_slice_4 = tensor.insert_slice %8#2 into %arg19[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<64x128xbf16> +// scf.yield %inserted_slice, %inserted_slice_3, %inserted_slice_4 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> +// } +// scf.yield %7#0, %7#1, %7#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> +// } +// scf.yield %6#0, %6#1, %6#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> +// } +// scf.forall.in_parallel { +// tensor.parallel_insert_slice %5#2 into %arg7[%3, %4] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<128x256xbf16> +// tensor.parallel_insert_slice %5#1 into %arg6[%3, %4] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<128x256xbf16> +// tensor.parallel_insert_slice %5#0 into %arg5[%3, %4] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<128x256xbf16> +// } +// } +// return %2#2 : tensor<128x256xbf16> +// } +// } + +// func.func @matmul_add(%arg0: tensor<8192x12288xf16>, %arg1: tensor<12288x16384xf16>, %arg2: tensor<8192x16384xf32>, %arg3: tensor<8192x16384xf32>) -> tensor<8192x16384xf32> { +// %0 = linalg.matmul {__fused_op__ = [0], name = "dot", tile_sizes = [[128, 128], [32, 32]]} ins(%arg0, %arg1 : tensor<8192x12288xf16>, tensor<12288x16384xf16>) outs(%arg2 : tensor<8192x16384xf32>) -> tensor<8192x16384xf32> +// %1 = tensor.empty() : tensor<8192x16384xf32> +// %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%0, %arg3 : tensor<8192x16384xf32>, tensor<8192x16384xf32>) outs(%1 : tensor<8192x16384xf32>) attrs = {__root_op__ = 0 : i64} { +// ^bb0(%in: f32, %in_0: f32, %out: f32): +// %4 = arith.addf %in, %in_0 : f32 +// linalg.yield %4 : f32 +// } -> tensor<8192x16384xf32> +// %c0 = arith.constant 0 : index +// %c8192 = arith.constant 8192 : index +// %c128 = arith.constant 128 : index +// %3 = 
scf.for %arg4 = %c0 to %c8192 step %c128 iter_args(%arg5 = %1) -> (tensor<8192x16384xf32>) { +// %c0_0 = arith.constant 0 : index +// %c16384 = arith.constant 16384 : index +// %c128_1 = arith.constant 128 : index +// %4 = scf.for %arg6 = %c0_0 to %c16384 step %c128_1 iter_args(%arg7 = %arg5) -> (tensor<8192x16384xf32>) { +// %extracted_slice = tensor.extract_slice %arg0[%arg4, 0] [128, 12288] [1, 1] : tensor<8192x12288xf16> to tensor<128x12288xf16> +// %extracted_slice_2 = tensor.extract_slice %arg1[0, %arg6] [12288, 128] [1, 1] : tensor<12288x16384xf16> to tensor<12288x128xf16> +// %extracted_slice_3 = tensor.extract_slice %arg2[%arg4, %arg6] [128, 128] [1, 1] : tensor<8192x16384xf32> to tensor<128x128xf32> +// %5 = linalg.matmul {__fused_op__ = [0], name = "dot", tile_sizes = [[128, 128], [32, 32]]} ins(%extracted_slice, %extracted_slice_2 : tensor<128x12288xf16>, tensor<12288x128xf16>) outs(%extracted_slice_3 : tensor<128x128xf32>) -> tensor<128x128xf32> +// %extracted_slice_4 = tensor.extract_slice %0[%arg4, %arg6] [128, 128] [1, 1] : tensor<8192x16384xf32> to tensor<128x128xf32> +// %extracted_slice_5 = tensor.extract_slice %arg3[%arg4, %arg6] [128, 128] [1, 1] : tensor<8192x16384xf32> to tensor<128x128xf32> +// %extracted_slice_6 = tensor.extract_slice %arg7[%arg4, %arg6] [128, 128] [1, 1] : tensor<8192x16384xf32> to tensor<128x128xf32> +// %6 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%5, %extracted_slice_5 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%extracted_slice_6 : tensor<128x128xf32>) attrs = {__root_op__ = 0 : i64} { +// ^bb0(%in: f32, %in_9: f32, %out: f32): +// %8 = arith.addf %in, %in_9 : f32 +// linalg.yield %8 : f32 +// } -> tensor<128x128xf32> +// %c0_7 = arith.constant 0 : index +// %c128_8 = arith.constant 128 : index +// %c32 = arith.constant 32 : index +// %7 = scf.for %arg8 = %c0_7 to %c128_8 step %c32 iter_args(%arg9 = %extracted_slice_6) -> (tensor<128x128xf32>) { +// %c0_9 = arith.constant 0 : index +// %c128_10 = arith.constant 128 : index +// %c32_11 = arith.constant 32 : index +// %8 = scf.for %arg10 = %c0_9 to %c128_10 step %c32_11 iter_args(%arg11 = %arg9) -> (tensor<128x128xf32>) { +// %extracted_slice_12 = tensor.extract_slice %extracted_slice[%arg8, 0] [32, 12288] [1, 1] : tensor<128x12288xf16> to tensor<32x12288xf16> +// %extracted_slice_13 = tensor.extract_slice %extracted_slice_2[0, %arg10] [12288, 32] [1, 1] : tensor<12288x128xf16> to tensor<12288x32xf16> +// %extracted_slice_14 = tensor.extract_slice %extracted_slice_3[%arg8, %arg10] [32, 32] [1, 1] : tensor<128x128xf32> to tensor<32x32xf32> +// %9 = linalg.matmul {__fused_op__ = [0], name = "dot", tile_sizes = [[128, 128], [32, 32]]} ins(%extracted_slice_12, %extracted_slice_13 : tensor<32x12288xf16>, tensor<12288x32xf16>) outs(%extracted_slice_14 : tensor<32x32xf32>) -> tensor<32x32xf32> +// %extracted_slice_15 = tensor.extract_slice %5[%arg8, %arg10] [32, 32] [1, 1] : tensor<128x128xf32> to tensor<32x32xf32> +// %extracted_slice_16 = tensor.extract_slice %extracted_slice_5[%arg8, %arg10] [32, 32] [1, 1] : tensor<128x128xf32> to tensor<32x32xf32> +// %extracted_slice_17 = tensor.extract_slice %arg11[%arg8, %arg10] [32, 32] [1, 1] : tensor<128x128xf32> to tensor<32x32xf32> +// %10 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", 
"parallel"]} ins(%9, %extracted_slice_16 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice_17 : tensor<32x32xf32>) attrs = {__root_op__ = 0 : i64} { +// ^bb0(%in: f32, %in_19: f32, %out: f32): +// %11 = arith.addf %in, %in_19 : f32 +// linalg.yield %11 : f32 +// } -> tensor<32x32xf32> +// %inserted_slice_18 = tensor.insert_slice %10 into %arg11[%arg8, %arg10] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<128x128xf32> +// scf.yield %inserted_slice_18 : tensor<128x128xf32> +// } {__parallel_loop__ = 1 : i64} +// scf.yield %8 : tensor<128x128xf32> +// } {__parallel_loop__ = 1 : i64} +// %inserted_slice = tensor.insert_slice %7 into %arg7[%arg4, %arg6] [128, 128] [1, 1] : tensor<128x128xf32> into tensor<8192x16384xf32> +// scf.yield %inserted_slice : tensor<8192x16384xf32> +// } {__parallel_loop__ = 0 : i64} +// scf.yield %4 : tensor<8192x16384xf32> +// } {__parallel_loop__ = 0 : i64} +// return %3 : tensor<8192x16384xf32> +// } From a7f0e21f077fec2f767af53ae04ecc0b2dc7f5b5 Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Tue, 25 Jun 2024 10:05:11 +0800 Subject: [PATCH 09/66] update reduce --- include/gc/Transforms/TilingVector.h | 266 +++ lib/gc/Transforms/CPUPhysicalResigterPass.cpp | 1749 +++++++++-------- lib/gc/Transforms/LowerTileVectorPass.cpp | 276 ++- 3 files changed, 1373 insertions(+), 918 deletions(-) create mode 100644 include/gc/Transforms/TilingVector.h diff --git a/include/gc/Transforms/TilingVector.h b/include/gc/Transforms/TilingVector.h new file mode 100644 index 000000000..0227a3794 --- /dev/null +++ b/include/gc/Transforms/TilingVector.h @@ -0,0 +1,266 @@ +//===- TilingVector.h - Graph Compiler passes -------------------------*- C++ +//-*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#ifndef GC_PASSES_TILINGVECTOR_H +#define GC_PASSES_TILINGVECTOR_H + +#include "gc/Transforms/Passes.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tensor/Transforms/Transforms.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/Passes.h" +#include "mlir/ExecutionEngine/Float16bits.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/CSE.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/ErrorHandling.h" +#include +#include +#include +#include +#include +#include +#include +#include +namespace mlir { +namespace gc { +namespace { + +Value makeIndexArithConstantOp(OpBuilder &opBuilder, Location &loc, int64_t x); +void rewriteOperationAsVectorize( + const std::queue &groupOps, + const llvm::DenseMap &opMap, OpBuilder &rewriter, + llvm::DenseMap &opPermuationMap); +void checkAndSetOperand( + Operation *op, const ValueRange &iterArgs, + const llvm::DenseMap &operandIdxMap, + const llvm::SmallVector &inductionVars, + const llvm::DenseMap &opPermuationMap); + +// 1. Classify operaions: +// classify the operations into : +// a. reorder, transpose. Reorder(or transpose) dim may bring data +// dependency. +// b. elemenwise. Those operations can be fused into a common for loop. +// c. broadcast. Need to analysis broadcast dim and the data +// dependency. +// d. reduction. Need to analysis broadcast dim and the +// data dependency. +// Same group operations have no data dependencies. They can be fused into a +// common for loop body. + +// Using queue to store the operation order. In order to ensure that +// subsequent moves to the operation will not cause semantic changes. 
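// A minimal usage sketch of the strategy declared below (assuming an already
// constructed `func::FuncOp funcOp`; the local variable names here are
// illustrative only):
//
//   VectorFusionStrategy strategy(funcOp);
//   strategy.run();                                  // builds the groups
//   auto &groups   = strategy.getOpGroups();         // one op queue per group
//   auto &groupIdx = strategy.getOpGroupIndexMap();  // Operation* -> group id
//
// Operations placed in the same group carry no blocking data dependency and
// can later be fused into a single loop nest; special operations
// (multi_reduction, shape_cast, transpose) are tracked separately and
// canonicalized on their own.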
+class VectorFusionStrategy { +private: + llvm::SmallVector, 8> opGroups; + // query current operation in which group, return group index + llvm::DenseMap opGroupIndexMap; + // can fused into prev operation which axis position + llvm::DenseMap opAnchorPos; + + llvm::SmallVector, 8> ignoreInitOperations; + + func::FuncOp func; + +public: + llvm::SmallVector, 8> &getOpGroups() { + return opGroups; + } + llvm::DenseMap &getOpGroupIndexMap() { + return opGroupIndexMap; + } + + func::FuncOp getFunc() { return func; } + llvm::SmallVector, 8> getIgnoreInitOperations() { + return ignoreInitOperations; + } + + VectorFusionStrategy() = default; + VectorFusionStrategy(func::FuncOp func) : func(func) {} + + void classifyOperations(); + + // run the vector fusion strategy + void run(); +}; + +enum CanonicalizerKind { OperationsGroup, Operations }; + +class MultiReductionCanonicalizer { +private: + llvm::SmallVector candidateRdOps; + llvm::SmallVector reductionAxis, parallelAxis; + bool haslastDimReduction = false; + bool isStandaloneOp = false; + int64_t typeRank = -1; + +public: + MultiReductionCanonicalizer( + const llvm::SmallVector &candidateRdOps) + : candidateRdOps(candidateRdOps) { + assert(candidateRdOps.size() > 1); + isStandaloneOp = candidateRdOps.size() == 1; + prepareReductionInfo(); + }; + int64_t getTypeRank(); + llvm::SmallVector &getCandidateOps(); + void getReductionAxisAndParallelAxis(); + bool hasLastDimReduction(); + bool getIsStandAloneOp() { return isStandaloneOp; } + void initReductionAxis(); + void initParallelAxis(); + llvm::SmallVector &getReductionAxis() { return reductionAxis; }; + llvm::SmallVector &getParallelAxis() { return parallelAxis; }; + void prepareReductionInfo(); +}; + +class CanonicalizerCommonUsedData { +private: + VectorFusionStrategy fusionStrategy; + // analysis the operation's operands and results + llvm::SmallVector, 8> groupOpResults, groupOpIterArgs; + + // store read and write operations permutation maps in order to convenient + // to replace loop induction var + llvm::DenseMap opPermuationMap; + llvm::SmallVector multiRdCanonicalizer; + +public: + CanonicalizerCommonUsedData() = default; + CanonicalizerCommonUsedData(VectorFusionStrategy &fusionStrategy) + : fusionStrategy(fusionStrategy){}; + + CanonicalizerCommonUsedData( + VectorFusionStrategy &fusionStrategy, + llvm::SmallVector, 8> &groupOpResults, + llvm::SmallVector, 8> &groupOpIterArgs, + llvm::DenseMap &opPermuationMap) + : fusionStrategy(fusionStrategy), groupOpResults(groupOpResults), + groupOpIterArgs(groupOpIterArgs), opPermuationMap(opPermuationMap) {} + + // set methods + void setFuseStrategy(VectorFusionStrategy &strategy) { + fusionStrategy = strategy; + auto opGroups = fusionStrategy.getOpGroups(); + if (opGroups.size() != groupOpResults.size() || + opGroups.size() != groupOpIterArgs.size()) { + groupOpResults.clear(); + groupOpIterArgs.clear(); + for (size_t i = 0; i < opGroups.size(); i++) { + groupOpResults.emplace_back(llvm::SetVector()); + groupOpIterArgs.emplace_back(llvm::SetVector()); + } + } + } + void + setGroupOpResults(llvm::SmallVector, 8> &results) { + groupOpResults = results; + } + void + setGroupOpIterArgs(llvm::SmallVector, 8> &iterArgs) { + groupOpIterArgs = iterArgs; + } + void setPermutationMap(llvm::DenseMap &map) { + opPermuationMap = map; + } + + // get methods + VectorFusionStrategy &getFusionStrategy() { return fusionStrategy; } + + llvm::SmallVector, 8> &getGroupOpResults() { + return groupOpResults; + } + + llvm::SmallVector, 8> &getGroupOpIterArgs() { + 
return groupOpIterArgs; + } + + llvm::DenseMap &getOpPermuationMap() { + return opPermuationMap; + } + llvm::SmallVector &getMultiRdCanonicalizer() { + return multiRdCanonicalizer; + } +}; + +class CanonicalizerVectorOperation { +private: + func::FuncOp func; + IRRewriter rewriter; + CanonicalizerKind kind; + CanonicalizerCommonUsedData commonUsedData; + +public: + CanonicalizerVectorOperation( + func::FuncOp func, + CanonicalizerKind kind = CanonicalizerKind::OperationsGroup) + : func(func), rewriter(func), kind(kind) { + // vector operation fusion + if (kind == CanonicalizerKind::OperationsGroup) { + auto fusionStrategy = VectorFusionStrategy(func); + fusionStrategy.run(); + commonUsedData.setFuseStrategy(fusionStrategy); + } + } + + // get functions + func::FuncOp &getFunc() { return func; }; + IRRewriter &getIRWewriter() { return rewriter; } + CanonicalizerCommonUsedData &getCommonUsedData() { return commonUsedData; } + + void generateGroupOpVectorizedIR(const int idx); + + void analysisGroupOperaionOperandsResults(); + + void analysisGroupOperationResults(); + + LogicalResult canonicalizeReductionOperation(); + LogicalResult canonicalizeTransposeOperation(vector::TransposeOp &transposeOp, + IRRewriter &rewriter); + void rewriteOperationAsVectorize(OpBuilder &rewriter, size_t groupId); + + // special operation methods + scf::ForOp generateMultiReductionForLoop(const size_t grpIdx); + void getCandidateSpecialOps(); + void canonicalizeSpecialOperation(); + scf::ForOp parallelAxisGenerateForLoop( + const int groupIdx, const int parallelIdx, ValueRange &initArgs, + llvm::SmallVector &inductionVars, Value &originalWriteResult); + scf::ForOp + reductionAxisGenerateForLoop(const int groupIdx, const size_t reductionIdx, + ValueRange &initArgs, + llvm::SmallVector &inductionVars); + + void run(); +}; +} // namespace +} // namespace gc +} // namespace mlir +#endif \ No newline at end of file diff --git a/lib/gc/Transforms/CPUPhysicalResigterPass.cpp b/lib/gc/Transforms/CPUPhysicalResigterPass.cpp index 4a11fb65b..ba4489476 100644 --- a/lib/gc/Transforms/CPUPhysicalResigterPass.cpp +++ b/lib/gc/Transforms/CPUPhysicalResigterPass.cpp @@ -6,32 +6,12 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// -#include "gc/Transforms/Passes.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/Dialect/Math/IR/Math.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Tensor/Transforms/Transforms.h" +#include "gc/Transforms/TilingVector.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/Dialect/Vector/Transforms/Passes.h" -#include "mlir/IR/AffineMap.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/CSE.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" -#include "llvm/Support/Casting.h" -#include -#include -#include -#include -#include -#include +#include "mlir/IR/Builders.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/SmallVector.h" namespace mlir { namespace gc { @@ -70,7 +50,13 @@ bool is_innermost_operation(Operation *op) { } int generateValidSteps(int steps, VectorType 
type) { - return type.getShape().back() >= steps ? steps : 1; + return type.getShape().back() >= steps ? steps > 16 ? 16 : steps : 1; +} + +// expr equals `vector rank` - 1 +bool isLastDim(const AffineExpr &expr, const size_t rank) { + return mlir::isa(expr) && + mlir::dyn_cast(expr).getPosition() == rank - 1; } // Get the maximum number of current data types that a register can hold @@ -109,6 +95,10 @@ mlir::FailureOr getOperationVectorType(Operation *op) { }) .Case( [&](arith::ConstantOp constantOp) { return failure(); }) + .Case( + [&](vector::MultiDimReductionOp multiReductionOp) { + return multiReductionOp.getSourceVectorType(); + }) .Default([&](Operation *op) -> mlir::FailureOr { if (!op->getResults().empty()) { auto t = mlir::dyn_cast(op->getResultTypes().front()); @@ -134,6 +124,275 @@ VectorType getVectorzedType(Operation *op) { return VectorType::get({loop_step}, vectorizedType.getElementType()); } +union Float32Bits { + uint32_t u; + float f; +}; + +const uint32_t kF32MantiBits = 23; +const uint32_t kF32HalfMantiBitDiff = 13; +const uint32_t kF32HalfBitDiff = 16; +const Float32Bits kF32Magic = {113 << kF32MantiBits}; +const uint32_t kF32HalfExpAdjust = (127 - 15) << kF32MantiBits; +const uint32_t kF32BfMantiBitDiff = 16; + +// Constructs the 16 bit representation for a half precision value from a float +// value. This implementation is adapted from Eigen. +uint16_t float2half(float floatValue) { + const Float32Bits inf = {255 << kF32MantiBits}; + const Float32Bits f16max = {(127 + 16) << kF32MantiBits}; + const Float32Bits denormMagic = {((127 - 15) + (kF32MantiBits - 10) + 1) + << kF32MantiBits}; + uint32_t signMask = 0x80000000u; + uint16_t halfValue = static_cast(0x0u); + Float32Bits f; + f.f = floatValue; + uint32_t sign = f.u & signMask; + f.u ^= sign; + + if (f.u >= f16max.u) { + const uint32_t halfQnan = 0x7e00; + const uint32_t halfInf = 0x7c00; + // Inf or NaN (all exponent bits set). + halfValue = (f.u > inf.u) ? halfQnan : halfInf; // NaN->qNaN and Inf->Inf + } else { + // (De)normalized number or zero. + if (f.u < kF32Magic.u) { + // The resulting FP16 is subnormal or zero. + // + // Use a magic value to align our 10 mantissa bits at the bottom of the + // float. As long as FP addition is round-to-nearest-even this works. + f.f += denormMagic.f; + + halfValue = static_cast(f.u - denormMagic.u); + } else { + uint32_t mantOdd = + (f.u >> kF32HalfMantiBitDiff) & 1; // Resulting mantissa is odd. + + // Update exponent, rounding bias part 1. The following expressions are + // equivalent to `f.u += ((unsigned int)(15 - 127) << kF32MantiBits) + + // 0xfff`, but without arithmetic overflow. + f.u += 0xc8000fffU; + // Rounding bias part 2. + f.u += mantOdd; + halfValue = static_cast(f.u >> kF32HalfMantiBitDiff); + } + } + + halfValue |= static_cast(sign >> kF32HalfBitDiff); + return halfValue; +} + +// Converts the 16 bit representation of a half precision value to a float +// value. This implementation is adapted from Eigen. +float half2float(uint16_t halfValue) { + const uint32_t shiftedExp = + 0x7c00 << kF32HalfMantiBitDiff; // Exponent mask after shift. + + // Initialize the float representation with the exponent/mantissa bits. + Float32Bits f = { + static_cast((halfValue & 0x7fff) << kF32HalfMantiBitDiff)}; + const uint32_t exp = shiftedExp & f.u; + f.u += kF32HalfExpAdjust; // Adjust the exponent + + // Handle exponent special cases. + if (exp == shiftedExp) { + // Inf/NaN + f.u += kF32HalfExpAdjust; + } else if (exp == 0) { + // Zero/Denormal? 
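    // (At this point the 10 half mantissa bits sit at the top of the f32
    //  mantissa field and the biased exponent is 112, set by the
    //  kF32HalfExpAdjust addition above. Adding 1 << kF32MantiBits bumps the
    //  exponent to 113, i.e. the value becomes 2^-14 * (1 + m/1024);
    //  subtracting kF32Magic.f, which equals 2^-14, then leaves exactly
    //  m * 2^-24: the half's subnormal value, and 0.0f for a zero input.)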
+ f.u += 1 << kF32MantiBits; + f.f -= kF32Magic.f; + } + + f.u |= (halfValue & 0x8000) << kF32HalfBitDiff; // Sign bit. + return f.f; +} + +// Constructs the 16 bit representation for a bfloat value from a float value. +// This implementation is adapted from Eigen. +uint16_t float2bfloat(float floatValue) { + if (std::isnan(floatValue)) + return std::signbit(floatValue) ? 0xFFC0 : 0x7FC0; + + Float32Bits floatBits; + floatBits.f = floatValue; + uint16_t bfloatBits; + + // Least significant bit of resulting bfloat. + uint32_t lsb = (floatBits.u >> kF32BfMantiBitDiff) & 1; + uint32_t roundingBias = 0x7fff + lsb; + floatBits.u += roundingBias; + bfloatBits = static_cast(floatBits.u >> kF32BfMantiBitDiff); + return bfloatBits; +} + +// Converts the 16 bit representation of a bfloat value to a float value. This +// implementation is adapted from Eigen. +float bfloat2float(uint16_t bfloatBits) { + Float32Bits floatBits; + floatBits.u = static_cast(bfloatBits) << kF32BfMantiBitDiff; + return floatBits.f; +} + +bool isReadWriteOnLastDim(Operation *op) { + if (mlir::isa(op) || + mlir::isa(op)) { + auto permutationMap = + mlir::dyn_cast(op) + ? mlir::dyn_cast(op).getPermutationMap() + : mlir::dyn_cast(op).getPermutationMap(); + auto rank = + mlir::dyn_cast(op) + ? mlir::dyn_cast(op->getOperand(0).getType()).getRank() + : mlir::dyn_cast(op->getOperand(0).getType()).getRank(); + auto dimExpr = permutationMap.getResults(); + bool find = false; + for (auto &expr : dimExpr) { + if (isLastDim(expr, rank)) { + find = true; + } + } + return find; + } + LDBG("The operation is not a read or write operation." << *op << "\n"); + return false; +} + +std::variant numeric_zero(Type type) { + Type t1 = getElementTypeOrSelf(type); + if (t1.isF32()) { + return 0.f; + } else if (t1.isBF16()) { + return bfloat2float(float2bfloat(0.f)); + } else if (t1.isF16()) { + return half2float(float2half(0.f)); + } else if (t1.isSignedInteger(8)) { + return int64_t(0); + } else if (t1.isSignedInteger(32)) { + return int64_t(0); + } else if (t1.isSignlessInteger(8) or t1.isSignlessInteger(32)) { + return int64_t(0); + } else { + LDBG("Unsupported data type: " << t1 << "\n"); + assert(0 && "unsupported data type"); + return (int64_t)0; + } +} + +std::variant numeric_one(Type type) { + Type t1 = getElementTypeOrSelf(type); + if (t1.isF32()) { + return 1.f; + } else if (t1.isBF16()) { + return bfloat2float(float2bfloat(1.f)); + } else if (t1.isF16()) { + return half2float(float2half(1.f)); + } else if (t1.isSignedInteger(8)) { + return int64_t(1); + } else if (t1.isSignedInteger(32)) { + return int64_t(1); + } else if (t1.isSignlessInteger(8) or t1.isSignlessInteger(32)) { + return int64_t(1); + } else { + LDBG("Unsupported data type: " << t1 << "\n"); + assert(0 && "unsupported data type"); + return (int64_t)1; + } +} + +std::variant numeric_limits_minimum(Type type) { + Type t1 = getElementTypeOrSelf(type); + if (t1.isF32()) { + return -std::numeric_limits::infinity(); + } else if (t1.isBF16()) { + return bfloat2float(float2bfloat(-std::numeric_limits::infinity())); + } else if (t1.isF16()) { + return (float)-65504; + } else if (t1.isSignedInteger(8)) { + return int64_t(-128); + } else if (t1.isSignedInteger(32)) { + return int64_t(std::numeric_limits::min()); + } else if (t1.isSignlessInteger(8) or t1.isSignlessInteger(32)) { + return int64_t(0); + } else { + LDBG("Unsupported data type: " << t1 << "\n"); + assert(0 && "unsupported data type"); + return (int64_t)0; + } +} + +std::variant numericLimitsMaximum(Type type) { + Type t1 
= getElementTypeOrSelf(type); + if (t1.isF32()) { + return std::numeric_limits::infinity(); + } else if (t1.isBF16()) { + return bfloat2float(float2bfloat(std::numeric_limits::infinity())); + } else if (t1.isF16()) { + return (float)65504; + } else if (t1.isSignedInteger(8)) { + return int64_t(127); + } else if (t1.isSignedInteger(32)) { + return int64_t(std::numeric_limits::max()); + } else if (t1.isSignlessInteger(8)) { + return int64_t(255); + } else if (t1.isSignedInteger(32)) { + return int64_t(std::numeric_limits::max()); + } else { + LDBG("Unsupported data type: " << t1 << "\n"); + assert(0 && "unsupported data type"); + return (int64_t)0; + } +} + +template +T getInitValForReduce(vector::CombiningKind kind, Type t) { + T result; + Type t1 = getElementTypeOrSelf(t); + + switch (kind) { + case vector::CombiningKind::ADD: + if (t1.isIntOrIndex()) + result = 0; + else if (llvm::isa(t1)) + result = 0.0f; + else + llvm_unreachable("invalid value types for ADD reduction"); + break; + case vector::CombiningKind::MAXNUMF: + case vector::CombiningKind::MAXIMUMF: + assert(llvm::isa(t1) && "expected float values"); + result = std::get(numeric_limits_minimum(t)); + break; + case vector::CombiningKind::MINNUMF: + case vector::CombiningKind::MINIMUMF: + assert(llvm::isa(t1) && "expected float values"); + result = std::get(numericLimitsMaximum(t)); + break; + case vector::CombiningKind::MAXSI: + case vector::CombiningKind::MAXUI: + assert(t1.isIntOrIndex() && "expected int values"); + result = std::get(numeric_limits_minimum(t)); + break; + case vector::CombiningKind::MINSI: + case vector::CombiningKind::MINUI: + assert(t1.isIntOrIndex() && "expected int values"); + result = std::get(numericLimitsMaximum(t)); + break; + case vector::CombiningKind::MUL: + if (t1.isIntOrIndex()) + result = 1; + else if (llvm::isa(t1)) + result = 1.f; + else + llvm_unreachable("invalid value types for MUL reduction"); + break; + default: + llvm_unreachable("unsupported reduction kind"); + }; + return result; +} + // Since we rewrite transfer_read and transfer_write, the `permutationmap` must // be changed. 
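// For example (a minimal sketch, not the implementation that follows; the
// `readOp` name is a hypothetical vector.transfer_read whose result has been
// tiled down to a rank-1 <16xf32> vector): if only the innermost,
// fastest-varying dimension survives the rewrite, an identity map on a 2-D
// source shrinks from (d0, d1) -> (d0, d1) to (d0, d1) -> (d1), which can be
// built with the upstream AffineMap API as
//
//   AffineMap oldMap = readOp.getPermutationMap();
//   AffineMap newMap = AffineMap::get(oldMap.getNumDims(), /*symbolCount=*/0,
//                                     {oldMap.getResults().back()},
//                                     readOp.getContext());
//
// so that the number of map results matches the rank of the new vector type.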
void setOpVectorizationPermutationMap(Operation *op, OpBuilder &rewriter, @@ -184,25 +443,67 @@ Type getScalarType(Operation *op) { } Operation *createTensorEmptyBefore(Operation *op) { - auto rtType = op->getResultTypes()[0].dyn_cast(); + auto rtType = mlir::dyn_cast(op->getResultTypes()[0]); IRRewriter reWriter(op); SmallVector shapes; SmallVector dynDims; for (unsigned i = 0; i < rtType.getRank(); i++) { shapes.push_back(rtType.getDimSize(i)); - if (rtType.isDynamicDim(i)) + if (rtType.isDynamicDim(i)) { dynDims.push_back(reWriter.create(reWriter.getUnknownLoc(), op->getResult(0), i)); + } } return reWriter.create(op->getLoc(), rtType.getShape(), rtType.getElementType(), dynDims); } +Value getOperationResultTensor(Operation *op) { + auto result = op->getResults()[0]; + for (auto x : result.getUsers()) { + if (mlir::isa(x)) { + return x->getOperand(1); + } + } + LDBG("Result not write back to tensor."); + + return createTensorEmptyBefore(op)->getResults()[0]; +} + +Operation *createTransferWriteOpAfter(Operation *op, const Value &dest) { + auto rtType = mlir::dyn_cast(op->getResultTypes()[0]); + auto rank = rtType.getRank(); + auto dstType = mlir::dyn_cast(dest.getType()); + IRRewriter reWriter(op); + + auto zero = + reWriter.create(reWriter.getUnknownLoc(), 0); + + reWriter.setInsertionPointAfter(op); + SmallVector inBoundsVal(rank, true); + + SmallVector shapes; + SmallVector dynDims; + for (unsigned i = 0; i < rtType.getRank(); i++) { + shapes.push_back(rtType.getDimSize(i)); + if (rtType.isDynamicDim(i)) { + dynDims.push_back(reWriter.create(reWriter.getUnknownLoc(), + op->getResult(0), i)); + } + } + return reWriter.create( + reWriter.getUnknownLoc(), + /*vector=*/op->getResult(0), + /*source=*/dest, + /*indices=*/SmallVector(dstType.getRank(), zero), + /*inBounds=*/inBoundsVal); +} + Operation * createTransferReadOpBefore(Operation *op, const Value &operand, vector::TransferReadOp *srcReadOp = nullptr) { - auto operandType = operand.getType().dyn_cast(); + auto operandType = mlir::dyn_cast(operand.getType()); IRRewriter rewriter(op); auto zero = @@ -212,7 +513,7 @@ createTransferReadOpBefore(Operation *op, const Value &operand, rewriter.getZeroAttr(operandType.getElementType())); if (srcReadOp) { - auto resultType = srcReadOp->getType().dyn_cast(); + auto resultType = mlir::dyn_cast(srcReadOp->getType()); SmallVector inBoundsVal(resultType.getRank(), true); auto srcReadOpAffineMap = srcReadOp->getPermutationMap(); // result of read operation should be same as operand @@ -240,6 +541,26 @@ createTransferReadOpBefore(Operation *op, const Value &operand, } } +// canonicalizing operation as tensor empty and transfer write the operation +// result into the empty tensor +[[nodiscard]] std::pair +canonicalizeSourceOperation(Operation *op) { + // auto emtpyOp = createTensorEmptyBefore(op); + auto resultTensor = getOperationResultTensor(op); + auto writeOp = createTransferWriteOpAfter(op, resultTensor); + return std::make_pair(resultTensor, writeOp->getResults()[0]); +} + +[[nodiscard]] Value +canonicalizeCurrentOperation(Operation *op, const Value &transferReadOperand, + size_t operandIdx, + vector::TransferReadOp *srcReadOp = nullptr) { + // transfer_read operation + auto readOp = createTransferReadOpBefore(op, transferReadOperand, srcReadOp); + op->setOperand(operandIdx, readOp->getResults()[0]); + return readOp->getResults()[0]; +} + // __________________________________ // Speical operations canonicalization // __________________________________ @@ -248,6 +569,82 @@ 
createTransferReadOpBefore(Operation *op, const Value &operand, // MultiReduce Operation //===----------------------------------------------------------------------===// +void getOpSourceOps(Operation *op, llvm::DenseSet &srcOps) { + llvm::SmallVector srcOperands = op->getOperands(); + std::deque srcOperandsQueue(srcOperands.begin(), srcOperands.end()); + llvm::DenseSet visited; + visited.insert(op); + while (!srcOperandsQueue.empty()) { + auto accOperand = srcOperandsQueue.front(); + srcOperandsQueue.pop_front(); + auto accOperandOp = accOperand.getDefiningOp(); + if (visited.count(accOperandOp)) { + continue; + } + visited.insert(accOperandOp); + srcOps.insert(accOperandOp); + auto accOperandOperands = accOperandOp->getOperands(); + srcOperandsQueue.insert(srcOperandsQueue.end(), accOperandOperands.begin(), + accOperandOperands.end()); + } +} + +bool isSrcRelated(const llvm::DenseSet &srcOps, Operation *op) { + return srcOps.count(op); +} + +void getPrevOps(std::queue &prevOps, + std::queue &opQueue, Operation *currentOp) { + while (!opQueue.empty() && currentOp != opQueue.front()) { + prevOps.push(opQueue.front()); + opQueue.pop(); + } +} + +void getPostOps(std::queue &postOps, + std::queue &opQueue, Operation *currentOp) { + // pop multireduction op + assert(currentOp == opQueue.front() && "Current operation is not the front " + "operation of the operation queue."); + opQueue.pop(); + while (!opQueue.empty()) { + postOps.push(opQueue.front()); + opQueue.pop(); + } +} + +void getReductionInitAttr(vector::MultiDimReductionOp &multiReductionOp, + Attribute &initValueAttr) { + auto vecType = multiReductionOp.getSourceVectorType(); + auto resultElementType = vecType.getElementType(); + if (isa(resultElementType)) { + initValueAttr = FloatAttr::get( + resultElementType, + getInitValForReduce(multiReductionOp.getKind(), vecType)); + } else { + initValueAttr = IntegerAttr::get( + resultElementType, + getInitValForReduce(multiReductionOp.getKind(), vecType)); + } +} + +void classifySourceRelatedOps(std::queue &accRelatedOps, + std::queue &sourceRelatedOps, + Operation *srcOp, + std::queue &prevOps) { + llvm::DenseSet srcOps; + getOpSourceOps(srcOp, srcOps); + while (!prevOps.empty()) { + auto op = prevOps.front(); + prevOps.pop(); + if (isSrcRelated(srcOps, op)) { + sourceRelatedOps.push(op); + } else { + accRelatedOps.push(op); + } + } +} + enum class MultiReduceOpAxisKind { Reduction, Parallel }; void updateReduceReadWriteOperationOperand( const llvm::SmallVector &inductionVars, @@ -263,7 +660,7 @@ void updateReduceReadWriteOperationOperand( } } -vector::TransferReadOp makeNewTransferReadOp( +vector::TransferReadOp cloneReductionTransferRead( Value &source, OpBuilder &b, IRMapping &readMap, const llvm::SmallVector ¶llelAxis, llvm::SmallVector &inductionVars, bool lastDimReduction, @@ -286,7 +683,7 @@ vector::TransferReadOp makeNewTransferReadOp( newReadOp->getResult(0).setType(newOperandType); setOpVectorizationPermutationMap( newReadOp, b, - newReadOp.getSource().getType().dyn_cast(), + mlir::dyn_cast(newReadOp.getSource().getType()), newReadOp.getPermutationMap()); rewriter.replaceOp(readOp, newReadOp); @@ -305,47 +702,76 @@ makeNewTransferWriteOp(Value source, IRMapping &writeMap, OpBuilder &b, MultiReduceOpAxisKind::Parallel); setOpVectorizationPermutationMap( newWriteOp, b, - newWriteOp->getResult(0).getType().dyn_cast(), + mlir::dyn_cast(newWriteOp->getResult(0).getType()), newWriteOp.getPermutationMap()); bodyRewriter.replaceOp(writeOp, newWriteOp); return newWriteOp; } -scf::ForOp 
reductionAxisGenerateForLoop( - OpBuilder &opBuilder, vector::MultiDimReductionOp &multiReductionOp, - const llvm::SmallVector &reductionAxis, - const size_t reductionIdx, const VectorType &vectorType, - llvm::SmallVector &inductionVars, const ValueRange &iterArgs, - bool lastDimReduction, Location &loc, const int loopStep) { - - auto zero = opBuilder.create( - loc, opBuilder.getIndexType(), - opBuilder.getIntegerAttr(opBuilder.getIndexType(), 0)); - auto forSteps = opBuilder.create( - loc, opBuilder.getIndexType(), - opBuilder.getIntegerAttr( - opBuilder.getIndexType(), - (reductionIdx == reductionAxis.size() - 1 && lastDimReduction) - ? loopStep - : 1)); - auto numIter = opBuilder.create( +Value makeIndexArithConstantOp(OpBuilder &opBuilder, Location &loc, int64_t x) { + return opBuilder.create( loc, opBuilder.getIndexType(), - opBuilder.getIntegerAttr(opBuilder.getIndexType(), - vectorType.getShape()[reductionIdx])); + opBuilder.getIntegerAttr(opBuilder.getIndexType(), x)); +} + +void moveOperationsToCurrentForBody( + std::queue opQueue, OpBuilder &b, ValueRange loopState, + const llvm::DenseMap &operandIdxMap, + const llvm::SmallVector &inductionVars, + llvm::DenseMap &opPermuationMap) { + Operation *lastOperation = opQueue.front(); + while (!opQueue.empty()) { + auto x = opQueue.front(); + opQueue.pop(); + if (lastOperation == x) { + x->moveBefore(b.getBlock(), b.getBlock()->begin()); + } else { + x->moveAfter(lastOperation); + lastOperation = x; + } + // check operation type to set correct operand + checkAndSetOperand(x, loopState, operandIdxMap, inductionVars, + opPermuationMap); + } +} + +scf::ForOp CanonicalizerVectorOperation::reductionAxisGenerateForLoop( + const int groupIdx, const size_t reductionIdx, ValueRange &initArgs, + llvm::SmallVector &inductionVars) { + MultiReductionCanonicalizer rdCanonicalizer = + commonUsedData.getMultiRdCanonicalizer()[groupIdx]; + auto &multireductionOp = rdCanonicalizer.getCandidateOps()[0]; + OpBuilder opBuilder(multireductionOp); + auto loc = multireductionOp->getLoc(); + auto &reductionAxis = rdCanonicalizer.getReductionAxis(); + auto lastDimReduction = rdCanonicalizer.hasLastDimReduction(); + auto vectorType = multireductionOp.getSourceVectorType(); + const int loopStep = getDataTypeMAXSIMDLength(vectorType); + auto isStandaloneOp = rdCanonicalizer.getIsStandAloneOp(); + + auto zero = makeIndexArithConstantOp(opBuilder, loc, 0); + auto forSteps = makeIndexArithConstantOp( + opBuilder, loc, + (reductionIdx == reductionAxis.size() - 1 && lastDimReduction) ? 
loopStep + : 1); + auto numIter = makeIndexArithConstantOp(opBuilder, loc, + vectorType.getShape()[reductionIdx]); auto forOp = opBuilder.create( - loc, zero, numIter, forSteps, iterArgs, + loc, zero, numIter, forSteps, initArgs, [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) { inductionVars.emplace_back(iv); if (reductionIdx == reductionAxis.size() - 1) { IRRewriter rewriter(b); IRMapping readMap; - Value reductionTarget = multiReductionOp->getOperand(0); + Value reductionTarget = multireductionOp.getSource(); llvm::SmallVector parallelAxis; - auto newReadOp = makeNewTransferReadOp( + auto newReadOp = cloneReductionTransferRead( reductionTarget, b, readMap, parallelAxis, inductionVars, lastDimReduction, MultiReduceOpAxisKind::Reduction); + if (isStandaloneOp) { + } // reduction or elementwise reduce // if (lastDimReduction) { // Operation *reductionOp = rewriter.create( @@ -354,15 +780,14 @@ scf::ForOp reductionAxisGenerateForLoop( // maybeYieldValue(b, loc, reductionOp->getResults()); // } else { auto reductionResult = - makeArithReduction(b, loc, multiReductionOp.getKind(), + makeArithReduction(b, loc, multireductionOp.getKind(), newReadOp->getResult(0), loopState.back()); maybeYieldValue(b, loc, reductionResult); // } } else { // outter loop - auto nxtFor = reductionAxisGenerateForLoop( - b, multiReductionOp, reductionAxis, reductionIdx + 1, vectorType, - inductionVars, loopState, lastDimReduction, loc, loopStep); + auto nxtFor = reductionAxisGenerateForLoop(groupIdx, reductionIdx + 1, + loopState, inductionVars); maybeYieldValue(b, loc, nxtFor->getResults()); } }); @@ -370,545 +795,404 @@ scf::ForOp reductionAxisGenerateForLoop( return forOp; } -scf::ForOp parallelAxisGenerateForLoop( - OpBuilder &opBuilder, vector::MultiDimReductionOp &multiReductionOp, - const llvm::SmallVector ¶llelAxis, const size_t parallelIdx, - const llvm::SmallVector &reductionAxis, - const size_t reductionIdx, const VectorType &vectorType, - llvm::SmallVector &inductionVars, const ValueRange &iterArgs, - Value &originalWriteResult, bool lastDimReduction, Location &loc, - const int loopStep) { - auto zero = opBuilder.create( - loc, opBuilder.getIndexType(), - opBuilder.getIntegerAttr(opBuilder.getIndexType(), 0)); - auto forSteps = opBuilder.create( - loc, opBuilder.getIndexType(), - opBuilder.getIntegerAttr(opBuilder.getIndexType(), 1)); - auto numIter = opBuilder.create( - loc, opBuilder.getIndexType(), - opBuilder.getIntegerAttr( - opBuilder.getIndexType(), - vectorType.getShape()[parallelAxis[parallelIdx]])); +scf::ForOp CanonicalizerVectorOperation::parallelAxisGenerateForLoop( + const int groupIdx, const int parallelIdx, ValueRange &initArgs, + llvm::SmallVector &inductionVars, Value &originalWriteResult) { + MultiReductionCanonicalizer rdCanonicalizer = + commonUsedData.getMultiRdCanonicalizer()[groupIdx]; + auto &multiReductionOp = rdCanonicalizer.getCandidateOps()[0]; + auto vectorType = multiReductionOp.getSourceVectorType(); + auto ¶llelAxis = rdCanonicalizer.getParallelAxis(); + auto isStandaloneOp = rdCanonicalizer.getIsStandAloneOp(); + auto lastDimReduction = rdCanonicalizer.hasLastDimReduction(); + OpBuilder opBuilder(multiReductionOp); + auto loc = multiReductionOp.getLoc(); + auto zero = makeIndexArithConstantOp(opBuilder, loc, 0); + auto forSteps = makeIndexArithConstantOp(opBuilder, loc, 1); + + // last dim reduction need to a generate dim=16 loop + int dimSize = 0; + if (lastDimReduction && parallelIdx == parallelAxis.size() && + !isStandaloneOp) { + dimSize = 16; + } else 
{ + dimSize = vectorType.getShape()[parallelAxis[parallelIdx]]; + } + auto numIter = makeIndexArithConstantOp(opBuilder, loc, dimSize); // Create a loop and move vectorized operation into loops. auto forOp = opBuilder.create( - loc, zero, numIter, forSteps, iterArgs, + loc, zero, numIter, forSteps, initArgs, [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) { inductionVars.emplace_back(iv); + auto fusionStrategy = commonUsedData.getFusionStrategy(); + auto opIndexMap = fusionStrategy.getOpGroupIndexMap(); + + assert(opIndexMap.contains(multiReductionOp) && + " Must constains multireduction operation."); + + auto opIndex = opIndexMap[multiReductionOp]; + auto &opGroups = fusionStrategy.getOpGroups(); + auto &opPermuationMap = commonUsedData.getOpPermuationMap(); + auto opQueue = opGroups[opIndex]; + auto multiReductionAcc = multiReductionOp.getAcc(); + auto accType = mlir::dyn_cast(multiReductionAcc.getType()); if (parallelIdx == parallelAxis.size() - 1) { + // four kinds of group operations + // If fused a operation, it means multirection must just + // constains last dim to do the reduction. + // 1. just multireduction + // two cases: + // 1. constaints last dims + // for ... parallel axis: + // transfer_read from accSource tensor + // arith.constant : vector<16xf32> + // for ... 16: + // for ... reduction axis: + // add + // scalar = reduction vector + // scalar insert into accVector + // transfer_write accVector into emtpy tensor + // 2. not last dims + // for ... generate axis: + // transfer_read from accSource tensor + // transfer_read from source tensor + // accVector = add + // transfer_write accVector into emtpy tensor + // 2. prev-op + multireduction + // In this case, there will be no tensor.empty + transfer_read + // operation, but the multireduction should write in an empty + // tensor + // for ... parallel axis: + // accVector and related accVector operation should be here + // extract from accVector scalar + // airth.constant : vector<16xf32> + // for ... reduction axis: + // prevop source op + // add + // scalar = reduction vector + // scalar insert into accVector + // transfer_write accVector into empty tensor + // + // 3. post-op + multireduction + // for ... parallel axis: + // transferread from accSource tensor + // arith.constant : vector<16xf32> + // for ... reduction axis: + // add + // postOp + // post Op transferWrite emtpy tensor + // scalar = reduction vector + // scalar insert into accVector + // transfer_write accVector into emtpy tensor + // 4. prev-op + multireduction + post-op + // for ... parallel axis: + // accVector operation + // extract from accVector a scalar + // arith.constant : vector<16xf32> + // for ... 
reduction axis: + // prev-op source op and related source operation + // add + // postOp + // post Op transferWrite emtpy tensor + // scalar = reduction vector + // scalar insert into accVector + // transfer_write accVector into emtpy tensor + + if (isStandaloneOp) { + // read operation + IRMapping accReadMap; + auto accReadOp = multiReductionAcc.getDefiningOp(); + assert(mlir::isa(accReadOp)); + accReadMap.map(accReadOp->getOperand(0), loopState.back()); + + auto newAccReadOp = cloneReductionTransferRead( + multiReductionAcc, b, accReadMap, parallelAxis, inductionVars, + lastDimReduction, MultiReduceOpAxisKind::Parallel); + // constructe next for loop + Attribute initValueAttr; + getReductionInitAttr(multiReductionOp, initValueAttr); + + auto accVal = b.create( + loc, DenseElementsAttr::get(accType, {initValueAttr})); + + ValueRange newIterArgs(accVal); + auto nxtFor = reductionAxisGenerateForLoop(groupIdx, 0, newIterArgs, + inductionVars); + + // insert accumulate value to original vector + auto accRes = nxtFor->getResults()[0]; + + Operation *reductionOp = b.create( + loc, multiReductionOp.getKind(), accRes); + auto insertOp = + b.create(loc, reductionOp->getResult(0), + newAccReadOp->getResults()[0], 0); + + // write vector back to tensor + vector::TransferWriteOp accWriteOp = nullptr; + for (auto [idx, x] : llvm::enumerate( + multiReductionOp->getResults()[0].getUsers())) { + if (idx == 0 && mlir::isa(x)) { + accWriteOp = mlir::dyn_cast(x); + break; + } + } + assert(accWriteOp && + " Not transfer_write operation. Current multireduction " + "operation may have wrong analysis IR."); + IRMapping accWriteindiceMap; + accWriteindiceMap.map(accWriteOp.getOperand(0), + insertOp->getResults()[0]); + auto writeResult = accWriteOp->getResults()[0]; + auto newAccWriteOp = makeNewTransferWriteOp( + writeResult, accWriteindiceMap, b, parallelAxis, inductionVars); + originalWriteResult = newAccWriteOp->getResult(0); + + maybeYieldValue(b, loc, newAccWriteOp->getResults()); + } + // else { + // auto prevOp = opQueue.front(); + // auto postOp = opQueue.back(); + + // if (mlir::isa(prevOp)) { + + // } else { + // if (mlir::isa(postOp)) { + // // prevOp + reduction op + // } else { + // // prevOp + reduction op + postOp + // // reduction op + postOp + + // getPrevOps(prevOps, opQueue, multiReductionOp); + // getPostOps(postOps, opQueue, multiReductionOp); + // // analysis acc related operation + // std::queue accRelatedOps, sourceRelatedOps; + // llvm::SmallVector iterArgsArray; + + // // prevOp need to classify + // classifySourceRelatedOps( + // accRelatedOps, sourceRelatedOps, + // multiReductionOp.getSource().getDefiningOp(), prevOps); + // rewriteOperationAsVectorize(prevOps, opIndexMap, b, + // opPermuationMap); + // moveOperationsToCurrentForBody(accRelatedOps, b, loopState, + // operandIdxMap, inductionVars, + // opPermuationMap); + // iterArgsArray.emplace_back(multiReductionAcc); + // ValueRange reductionAxisArgs(iterArgsArray); + // auto nxtFor = parallelAxisGenerateForLoop( + // b, multiReductionOp, parallelAxis, parallelIdx + 1, + // reductionAxis, reductionIdx, vectorType, inductionVars, + // loopState, operandIdxMap, originalWriteResult, + // lastDimReduction, loc, loopStep, canonicalizer, + // isStandaloneOp); + + // // prepare iterArgs + // } + // } + // } - // read operation - IRMapping accReadMap; - auto multiReductionAcc = multiReductionOp.getAcc(); - auto accReadOp = multiReductionAcc.getDefiningOp(); - accReadMap.map(accReadOp->getOperand(0), loopState.back()); - - auto 
newAccReadOp = makeNewTransferReadOp( - multiReductionAcc, b, accReadMap, parallelAxis, inductionVars, - lastDimReduction, MultiReduceOpAxisKind::Parallel); - auto resultElementType = vectorType.getElementType(); - // constructe next for loop - // auto accVal = b.create( - // loc, opBuilder.getZeroAttr(vectorType.getElementType())); - Attribute initValueAttr; - if (isa(resultElementType)) { - initValueAttr = FloatAttr::get(resultElementType, 0.0); + } else { + if (parallelIdx == parallelAxis.size() && !isStandaloneOp && + lastDimReduction) { + + Attribute initValueAttr; + getReductionInitAttr(multiReductionOp, initValueAttr); + + auto accVal = b.create( + loc, DenseElementsAttr::get(accType, {initValueAttr})); + ValueRange newIterArgs(accVal); + auto nxtFor = reductionAxisGenerateForLoop(groupIdx, 0, newIterArgs, + inductionVars); + // insert accumulate value to original vector + auto accRes = nxtFor->getResults()[0]; + + Operation *reductionOp = b.create( + loc, multiReductionOp.getKind(), accRes); + auto insertOp = b.create( + loc, reductionOp->getResult(0), initArgs[0], iv); + maybeYieldValue(b, loc, insertOp->getResults()); } else { - initValueAttr = IntegerAttr::get(resultElementType, 0); - } - auto accVal = b.create( - loc, DenseElementsAttr::get(getVectorzedType(multiReductionOp), - {initValueAttr})); - - ValueRange newIterArgs(accVal); - auto nxtFor = reductionAxisGenerateForLoop( - b, multiReductionOp, reductionAxis, reductionIdx, vectorType, - inductionVars, newIterArgs, lastDimReduction, loc, loopStep); - - // insert accumulate value to original vector - auto accRes = nxtFor->getResults()[0]; - - Operation *reductionOp = b.create( - loc, multiReductionOp.getKind(), accRes); - auto insertOp = b.create( - loc, reductionOp->getResult(0), newAccReadOp->getResults()[0], 0); - - // write vector back to tensor - vector::TransferWriteOp accWriteOp = nullptr; - for (auto [idx, x] : - llvm::enumerate(multiReductionOp->getResults()[0].getUsers())) { - if (idx == 0 && mlir::isa(x)) { - accWriteOp = mlir::dyn_cast(x); - break; - } + auto nxtFor = parallelAxisGenerateForLoop(groupIdx, parallelIdx + 1, + loopState, inductionVars, + originalWriteResult); + maybeYieldValue(b, loc, nxtFor->getResults()); } - assert(accWriteOp && - " Not transfer_write operation. 
Current multireduction " - "operation may have wrong analysis IR."); - IRMapping accWriteindiceMap; - accWriteindiceMap.map(accWriteOp.getOperand(0), - insertOp->getResults()[0]); - auto writeResult = accWriteOp->getResults()[0]; - auto newAccWriteOp = makeNewTransferWriteOp( - writeResult, accWriteindiceMap, b, parallelAxis, inductionVars); - originalWriteResult = newAccWriteOp->getResult(0); - - maybeYieldValue(b, loc, newAccWriteOp->getResults()); - } else { - auto nxtFor = parallelAxisGenerateForLoop( - b, multiReductionOp, parallelAxis, parallelIdx + 1, reductionAxis, - reductionIdx, vectorType, inductionVars, loopState, - originalWriteResult, lastDimReduction, loc, loopStep); - maybeYieldValue(b, loc, nxtFor->getResults()); } }); return forOp; } -scf::ForOp generateMultiReductionForLoop( - OpBuilder &opBuilder, vector::MultiDimReductionOp &multiReductionOp, - const llvm::SmallVector ¶llelAxis, const size_t parallelIdx, - const llvm::SmallVector &reductionAxis, - const size_t reductionIdx, const VectorType &vectorType, - llvm::SmallVector &inductionVars, const ValueRange &iterArgs, - Value &originalWriteResult, bool lastDimReduction) { - const int loopStep = getDataTypeMAXSIMDLength(vectorType); - auto loc = multiReductionOp->getLoc(); +scf::ForOp CanonicalizerVectorOperation::generateMultiReductionForLoop( + const size_t grpIdx) { + auto &grpResults = commonUsedData.getGroupOpResults()[grpIdx]; + llvm::SmallVector forLoopArgs(grpResults.begin(), grpResults.end()); + llvm::SmallVector inductionVars; + ValueRange initArgs(forLoopArgs); + Value originalWriteResult; + + scf::ForOp forOp = parallelAxisGenerateForLoop(0, 0, initArgs, inductionVars, + originalWriteResult); + auto replaceIfFn = [&](OpOperand &use) { + return use.getOwner()->getBlock() != + originalWriteResult.getDefiningOp()->getBlock(); + }; + rewriter.replaceOpUsesWithIf(originalWriteResult.getDefiningOp(), + forOp->getResults()[0], replaceIfFn); - scf::ForOp forOp = parallelAxisGenerateForLoop( - opBuilder, multiReductionOp, parallelAxis, parallelIdx, reductionAxis, - reductionIdx, vectorType, inductionVars, iterArgs, originalWriteResult, - lastDimReduction, loc, loopStep); + rewriter.replaceOp( + commonUsedData.getMultiRdCanonicalizer()[grpIdx].getCandidateOps()[0], + forOp); return forOp; } -// mlir::FailureOr generateTransposeForLoop( -// OpBuilder &opBuilder, vector::TransposeOp &transposeOp, -// const llvm::SmallVector ¶llelAxis, const size_t -// parallelIdx, const llvm::SmallVector &reductionAxis, const -// size_t reductionIdx, const VectorType &vectorType, -// llvm::SmallVector &inductionVars, const ValueRange &iterArgs, -// Value &originalWriteResult, bool lastDimReduction) { -// const int loop_step = getDataTypeMAXSIMDLength(vectorType); -// auto loc = transposeOp->getLoc(); -// auto zero = opBuilder.create( -// loc, opBuilder.getIndexType(), -// opBuilder.getIntegerAttr(opBuilder.getIndexType(), 0)); - -// scf::ForOp forOp = nullptr; -// // parallel axis -// if (parallelIdx < parallelAxis.size()) { -// auto forSteps = opBuilder.create( -// loc, opBuilder.getIndexType(), -// opBuilder.getIntegerAttr(opBuilder.getIndexType(), 1)); -// auto numIter = opBuilder.create( -// loc, opBuilder.getIndexType(), -// opBuilder.getIntegerAttr( -// opBuilder.getIndexType(), -// vectorType.getShape()[parallelAxis[parallelIdx]])); -// // Create a loop and move vectorized operation into loops. 
-// forOp = opBuilder.create( -// loc, zero, numIter, forSteps, iterArgs, -// [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) { -// inductionVars.emplace_back(iv); -// if (parallelIdx == parallelAxis.size() - 1) { -// // move original transfer_read operation into parallel axis loop -// // body -// // get read operation -// Value multiReductionAcc = multiReductionOp.getAcc(); -// auto accReadOp = -// multiReductionAcc.getDefiningOp(); -// assert(accReadOp && -// " Not transfer_read operation. Current multireduction " -// "operation may have wrong analysis IR."); -// // get write operation -// vector::TransferWriteOp accWriteOp = nullptr; -// for (auto [idx, x] : llvm::enumerate( -// multiReductionOp->getResults()[0].getUsers())) { -// if (idx == 0 && mlir::isa(x)) { -// accWriteOp = mlir::dyn_cast(x); -// break; -// } -// } -// assert(accWriteOp); -// IRMapping accReadindiceMap; - -// IRRewriter bodyRewriter(b); -// auto newAccReadOp = mlir::dyn_cast( -// b.clone(*accReadOp, accReadindiceMap)); -// bodyRewriter.replaceOp(accReadOp, newAccReadOp); -// int offset = 1; -// for (auto [idx, inductionVar] : llvm::enumerate(inductionVars)) { -// if (idx >= parallelAxis.size()) { -// break; -// } -// newAccReadOp->setOperand(idx + offset, inductionVar); -// } -// auto newOperandType = getScalarType(newAccReadOp); -// newAccReadOp->getResult(0).setType(newOperandType); -// setOpVectorizationPermutationMap( -// newAccReadOp, b, -// newAccReadOp.getSource().getType().dyn_cast(), -// newAccReadOp.getPermutationMap()); -// // constructe next for loop -// auto accVal = b.create( -// loc, opBuilder.getZeroAttr(vectorType.getElementType())); -// ValueRange newIterArgs(accVal); -// auto nxtFor = generateMultiReductionForLoop( -// b, multiReductionOp, parallelAxis, parallelIdx + 1, -// reductionAxis, reductionIdx, vectorType, inductionVars, -// newIterArgs, originalWriteResult, lastDimReduction); - -// // move original transfer_write into loop -// auto accRes = nxtFor.value()->getResults()[0]; - -// // replace the vector as the loop return vector value -// llvm::SmallVector insertPos; -// auto insertOp = b.create( -// loc, accRes, newAccReadOp->getResult(0), 0); -// IRMapping accWriteindiceMap; -// accWriteindiceMap.map(accWriteOp.getOperand(0), -// insertOp->getResults()[0]); -// auto newAccWriteOp = mlir::dyn_cast( -// b.clone(*accWriteOp, accWriteindiceMap)); -// offset = 2; -// for (auto [idx, inductionVar] : llvm::enumerate(inductionVars)) { -// if (idx >= parallelAxis.size()) { -// break; -// } -// newAccWriteOp->setOperand(idx + offset, inductionVar); -// } -// setOpVectorizationPermutationMap(newAccWriteOp, b, -// newAccWriteOp->getResult(0) -// .getType() -// .dyn_cast(), -// newAccWriteOp.getPermutationMap()); -// bodyRewriter.replaceOp(accWriteOp, newAccWriteOp); -// originalWriteResult = newAccWriteOp->getResult(0); -// maybeYieldValue(b, loc, newAccWriteOp->getResults()); -// } else { -// auto nxtFor = generateMultiReductionForLoop( -// b, multiReductionOp, parallelAxis, parallelIdx + 1, -// reductionAxis, reductionIdx, vectorType, inductionVars, -// iterArgs, originalWriteResult, lastDimReduction); -// maybeYieldValue(b, loc, nxtFor.value()->getResults()); -// } -// }); - -// } else { - -// auto forSteps = opBuilder.create( -// loc, opBuilder.getIndexType(), -// opBuilder.getIntegerAttr( -// opBuilder.getIndexType(), -// reductionIdx == reductionAxis.size() - 1 ? 
loop_step : 1)); -// auto numIter = opBuilder.create( -// loc, opBuilder.getIndexType(), -// opBuilder.getIntegerAttr(opBuilder.getIndexType(), -// vectorType.getShape()[reductionIdx])); -// forOp = opBuilder.create( -// loc, zero, numIter, forSteps, iterArgs, -// [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) { -// inductionVars.emplace_back(iv); - -// if (reductionIdx == reductionAxis.size() - 1) { -// auto source = multiReductionOp->getOperand(0); - -// auto readOp = -// mlir::dyn_cast(source.getDefiningOp()); -// assert(readOp); -// IRMapping indiceMap; -// IRRewriter rewriter(b); -// auto clonedOp = b.clone(*readOp, indiceMap); -// int offset = 1; -// auto newReadOp = -// mlir::dyn_cast(clonedOp); - -// for (auto [idx, inductionVar] : enumerate(inductionVars)) { -// newReadOp->setOperand(idx + offset, inductionVar); -// } - -// auto newOperandType = lastDimReduction ? -// getVectorzedType(newReadOp) -// : -// getScalarType(newReadOp); -// newReadOp->getResult(0).setType(newOperandType); -// setOpVectorizationPermutationMap( -// newReadOp, b, -// newReadOp.getSource().getType().dyn_cast(), -// newReadOp.getPermutationMap()); -// rewriter.replaceOp(readOp, newReadOp); -// if (lastDimReduction) { -// Operation *reductionOp = rewriter.create( -// loc, multiReductionOp.getKind(), newReadOp->getResult(0), -// loopState.back()); -// maybeYieldValue(b, loc, reductionOp->getResults()); -// } else { -// auto reductionResult = -// makeArithReduction(b, loc, multiReductionOp.getKind(), -// newReadOp->getResult(0), -// loopState.back()); -// maybeYieldValue(b, loc, reductionResult); -// } -// } else { -// // outter loop -// auto nxtFor = generateMultiReductionForLoop( -// b, multiReductionOp, parallelAxis, parallelIdx, -// reductionAxis, reductionIdx + 1, vectorType, inductionVars, -// iterArgs, originalWriteResult, lastDimReduction); -// maybeYieldValue(b, loc, nxtFor.value()->getResults()); -// } -// }); -// } -// return forOp; -// } - -// 1. Classify operaions: -// classify the operations into : -// a. reorder, transpose. Reorder(or transpose) dim may bring data -// dependency. -// b. elemenwise. Those operations can be fused into a common for loop. -// c. broadcast. Need to analysis broadcast dim and the data -// dependency. -// d. reduction. Need to analysis broadcast dim and the -// data dependency. -// Same group operations have no data dependencies. They can be fused into a -// common for loop body. - -// Using queue to store the operation order. In order to ensure that -// subsequent moves to the operation will not cause semantic changes. 
-class VectorFusionStrategy { -public: - llvm::SmallVector, 8> &getOpGroups() { - return opGroups; - } - llvm::DenseMap &getOpGroupIndexMap() { - return opGroupIndexMap; - } - - func::FuncOp getFunc() { return func; } - - VectorFusionStrategy() = default; - VectorFusionStrategy(func::FuncOp func) : func(func) {} - - void - classifyOperations(func::FuncOp func, - llvm::SmallVector, 8> &opGroups, - llvm::DenseMap &opGroupIndexMap); - - // run the vector fusion strategy - void run(); +llvm::SmallVector & +MultiReductionCanonicalizer::getCandidateOps() { + return candidateRdOps; +}; -private: - llvm::SmallVector, 8> opGroups; - // query current operation in which group, return group index - llvm::DenseMap opGroupIndexMap; +void CanonicalizerVectorOperation::getCandidateSpecialOps() { + auto grp = commonUsedData.getFusionStrategy().getOpGroups(); + // avoid seg fault + auto multiRdCanonicalizer = commonUsedData.getMultiRdCanonicalizer(); + multiRdCanonicalizer.clear(); + size_t start = 0; + while (start++ < grp.size()) { + multiRdCanonicalizer.emplace_back(MultiReductionCanonicalizer({})); + } - func::FuncOp func; + auto idxGroup = commonUsedData.getFusionStrategy().getOpGroupIndexMap(); + func->walk([&](Operation *op) { + llvm::TypeSwitch(op) + .Case( + [&](vector::MultiDimReductionOp multiReductionOp) { + auto groupIdx = idxGroup[multiReductionOp]; + multiRdCanonicalizer[groupIdx].getCandidateOps().emplace_back( + multiReductionOp); + }) + .Case([&](vector::ShapeCastOp shapeCastOp) { + // shapeCastOps.insert(shapeCastOp); + // TODO + assert(0); + }) + .Case([&](vector::TransposeOp transposeOp) { + // transposeOps.insert(transposeOp); + // TODO + assert(0); + }) + .Default([&](Operation *) {}); + }); }; -void VectorFusionStrategy::run() { - classifyOperations(func, opGroups, opGroupIndexMap); +void MultiReductionCanonicalizer::initReductionAxis() { + auto reductionAxisRange = + candidateRdOps[0].getReductionDims().getAsValueRange(); + auto reductionRange = llvm::to_vector<4>(llvm::map_range( + reductionAxisRange, [](const APInt &a) { return a.getZExtValue(); })); + reductionAxis.assign(reductionRange.begin(), reductionRange.end()); } -enum CanonicalizerKind { OperationsGroup, Operations }; - -class CanonicalizerVectorOperation { -public: - func::FuncOp func; - IRRewriter rewriter; - VectorFusionStrategy fusionStrategy; - CanonicalizerKind kind; - - // analysis the operation's operands and results - llvm::SmallVector, 8> groupOpResults, groupOpIterArgs; - - // store read and write operations permutation maps in order to convenient - // to replace loop induction var - llvm::DenseMap opPermuationMap; - - CanonicalizerVectorOperation( - func::FuncOp func, - CanonicalizerKind kind = CanonicalizerKind::OperationsGroup) - : func(func), rewriter(func), kind(kind) { - // vector operation fusion - if (kind == CanonicalizerKind::OperationsGroup) { - fusionStrategy = VectorFusionStrategy(func); - fusionStrategy.run(); +void MultiReductionCanonicalizer::initParallelAxis() { + llvm::SmallDenseSet reductionAxisSet(reductionAxis.begin(), + reductionAxis.end()); + for (int64_t i = 0; i < typeRank; ++i) { + if (!reductionAxisSet.contains(i)) { + parallelAxis.push_back(i); } } - func::FuncOp getFunc() { return func; }; - - void generateGroupOpVectorizedIR( - const int idx, std::queue &grp, - llvm::DenseMap &opGroupIndexMap); - - void analysisGroupOperaionOperandsResults( - llvm::SmallVector, 8> &opGroups, - llvm::DenseMap &opGroupIndexMap); + llvm::sort(parallelAxis.begin(), parallelAxis.end()); +} - void 
analysisGroupOperationResults( - llvm::SmallVector, 8> &opGroups, - llvm::DenseMap &opGroupIndexMap); +int64_t MultiReductionCanonicalizer::getTypeRank() { + auto srcVecType = candidateRdOps[0].getSourceVectorType(); + auto srcRank = srcVecType.getRank(); + typeRank = srcRank; + return srcRank; +} - void canonicalizeSpecialOperation(); - LogicalResult - canonicalizeReductionOperation(vector::MultiDimReductionOp &multiReductionOp, - IRRewriter &rewriter); - LogicalResult canonicalizeTransposeOperation(vector::TransposeOp &transposeOp, - IRRewriter &rewriter); +void MultiReductionCanonicalizer::getReductionAxisAndParallelAxis() { + initReductionAxis(); + initParallelAxis(); +} - void run(); +bool MultiReductionCanonicalizer::hasLastDimReduction() { + llvm::SmallDenseSet reductionAxisSet(reductionAxis.begin(), + reductionAxis.end()); + bool res = false; + if (reductionAxisSet.contains(typeRank - 1)) { + res = true; + } + haslastDimReduction = res; + return res; +} -private: - llvm::SetVector multiReductionOps; - llvm::SetVector shapeCastOps; +void MultiReductionCanonicalizer::prepareReductionInfo() { + getTypeRank(); + getReductionAxisAndParallelAxis(); + hasLastDimReduction(); }; -// LogicalResult CanonicalizerVectorOperation::canonicalizeTransposeOperation( -// vector::TransposeOp &transposeOp, IRRewriter &rewriter) { -// OpBuilder::InsertionGuard guard(rewriter); - -// auto srcVecType = multiReductionOp.getSourceVectorType(); -// auto srcRank = multiReductionOp.getSourceVectorType().getRank(); - -// // Separate reduction and parallel dims -// bool lastDimReduction = false; -// auto reductionAxisRange = -// multiReductionOp.getReductionDims().getAsValueRange(); -// auto reductionRange = llvm::to_vector<4>(llvm::map_range( -// reductionAxisRange, [](const APInt &a) { return a.getZExtValue(); })); -// llvm::SmallVector reductionAxis(reductionRange.begin(), -// reductionRange.end()); -// llvm::SmallDenseSet reductionAxisSet(reductionAxis.begin(), -// reductionAxis.end()); -// if (reductionAxisSet.contains(srcRank - 1)) { -// lastDimReduction = true; -// } -// SmallVector parallelAxis; -// for (int64_t i = 0; i < srcRank; ++i) { -// if (!reductionAxisSet.contains(i)) { -// parallelAxis.push_back(i); -// } -// } -// /* -// * The final IR may look like below: -// * _for_(_fuseiter_i, 0, 1) -// * sum = 0; -// * _for_(_fuseiter_j, 0, 1) -// * _for_(_fuseiter_k, 0, 1) -// * sum += src[src_idx]; -// * dst[dst_idx] = sum; -// * */ -// Operation *newReduction; -// Value multiReductionAcc = multiReductionOp.getAcc(); -// auto accTensorReadOp = -// multiReductionAcc.getDefiningOp(); -// Value originalWriteResult; -// ValueRange iterArgs(accTensorReadOp->getOperand(0)); -// llvm::SmallVector inductionVars; -// auto forOp = generateMultiReductionForLoop( -// rewriter, multiReductionOp, parallelAxis, 0, reductionAxis, 0, -// srcVecType, inductionVars, iterArgs, originalWriteResult, -// lastDimReduction); -// if (failed(forOp)) { -// LDBG("MultiReduction Operation lowering failed"); -// return failure(); -// } -// auto replaceIfFn = [&](OpOperand &use) { -// return use.getOwner()->getBlock() != -// originalWriteResult.getDefiningOp()->getBlock(); -// }; -// newReduction = forOp.value(); -// rewriter.replaceOpUsesWithIf(originalWriteResult.getDefiningOp(), -// newReduction->getResults()[0], replaceIfFn); - -// rewriter.replaceOp(multiReductionOp, newReduction); -// return success(); -// } - -LogicalResult CanonicalizerVectorOperation::canonicalizeReductionOperation( - vector::MultiDimReductionOp 
&multiReductionOp, IRRewriter &rewriter) { +LogicalResult CanonicalizerVectorOperation::canonicalizeReductionOperation() { OpBuilder::InsertionGuard guard(rewriter); - auto srcVecType = multiReductionOp.getSourceVectorType(); - auto srcRank = multiReductionOp.getSourceVectorType().getRank(); - - // Separate reduction and parallel dims - bool lastDimReduction = false; - auto reductionAxisRange = - multiReductionOp.getReductionDims().getAsValueRange(); - auto reductionRange = llvm::to_vector<4>(llvm::map_range( - reductionAxisRange, [](const APInt &a) { return a.getZExtValue(); })); - llvm::SmallVector reductionAxis(reductionRange.begin(), - reductionRange.end()); - llvm::SmallDenseSet reductionAxisSet(reductionAxis.begin(), - reductionAxis.end()); - if (reductionAxisSet.contains(srcRank - 1)) { - lastDimReduction = true; - } - SmallVector parallelAxis; - for (int64_t i = 0; i < srcRank; ++i) { - if (!reductionAxisSet.contains(i)) { - parallelAxis.push_back(i); + // traverse all groups + auto &multiRdCanonicalizer = commonUsedData.getMultiRdCanonicalizer(); + for (auto [groupId, rdCanonicalizer] : + llvm::enumerate(multiRdCanonicalizer)) { + auto &candidateOps = rdCanonicalizer.getCandidateOps(); + if (candidateOps.empty()) { + continue; } + // generate MultiReduction for loops + auto forOp = generateMultiReductionForLoop(groupId); + // update uses } - /* - * The final IR may look like below: - * _for_(_fuseiter_i, 0, 1) - * sum = 0; - * _for_(_fuseiter_j, 0, 1) - * _for_(_fuseiter_k, 0, 1) - * sum += src[src_idx]; - * dst[dst_idx] = sum; - * */ - Operation *newReduction; - Value multiReductionAcc = multiReductionOp.getAcc(); - auto accTensorReadOp = - multiReductionAcc.getDefiningOp(); - Value originalWriteResult; - ValueRange iterArgs(accTensorReadOp->getOperand(0)); - llvm::SmallVector inductionVars; - auto forOp = generateMultiReductionForLoop( - rewriter, multiReductionOp, parallelAxis, 0, reductionAxis, 0, srcVecType, - inductionVars, iterArgs, originalWriteResult, lastDimReduction); - auto replaceIfFn = [&](OpOperand &use) { - return use.getOwner()->getBlock() != - originalWriteResult.getDefiningOp()->getBlock(); - }; - newReduction = forOp; - rewriter.replaceOpUsesWithIf(originalWriteResult.getDefiningOp(), - newReduction->getResults()[0], replaceIfFn); + // Separate reduction and parallel dims + // Operation *newReduction; + // auto accSourceOp = multiReductionAcc.getDefiningOp(); + // llvm::SmallVector initIterArgs; + // // process Acc operand + // if (mlir::dyn_cast(accSourceOp)) { + // auto accTensorReadOp = + // multiReductionAcc.getDefiningOp(); + // initIterArgs.emplace_back(accTensorReadOp->getOperand(0)); + // } + // auto dstOperandSet = commonUsedData.getGroupOpIterArgs()[grpIdx]; + // llvm::SmallVector operands; + // llvm::DenseMap operandIdxMap; + // for (auto [idx, x] : llvm::enumerate(dstOperandSet)) { + // initIterArgs.emplace_back(x); + // operandIdxMap[x] = operands.size() - 1; + // } - rewriter.replaceOp(multiReductionOp, newReduction); + // Value originalWriteResult; + // ValueRange iterArgs(initIterArgs); + // llvm::SmallVector inductionVars; + // auto forOp = generateMultiReductionForLoop( + // rewriter, multiReductionOp, parallelAxis, 0, reductionAxis, 0, + // srcVecType, inductionVars, iterArgs, operandIdxMap, + // originalWriteResult, *this, lastDimReduction, isStandaloneOp); + // auto replaceIfFn = [&](OpOperand &use) { + // return use.getOwner()->getBlock() != + // originalWriteResult.getDefiningOp()->getBlock(); + // }; + // newReduction = forOp; + // 
rewriter.replaceOpUsesWithIf(originalWriteResult.getDefiningOp(), + // newReduction->getResults()[0], replaceIfFn); + + // rewriter.replaceOp(firstOp, newReduction); return success(); } void CanonicalizerVectorOperation::canonicalizeSpecialOperation() { - func->walk([&](Operation *op) { - llvm::TypeSwitch(op) - .Case( - [&](vector::MultiDimReductionOp multiReductionOp) { - multiReductionOps.insert(multiReductionOp); - }) - .Case([&](vector::ShapeCastOp shapeCastOp) { - shapeCastOps.insert(shapeCastOp); - }) - .Default([&](Operation *) {}); - }); - // process reduction - for (auto x : multiReductionOps) { - IRRewriter rewriter(x); - (void)canonicalizeReductionOperation(x, rewriter); - } - // process shapecast - // for (auto x : shapeCastOps) { - // } - return; + // multireduction operation + auto result = canonicalizeReductionOperation(); + // canonicalizeBroadCastOperation(); } void CanonicalizerVectorOperation::run() { - + auto &fusionStrategy = commonUsedData.getFusionStrategy(); if (kind == CanonicalizerKind::OperationsGroup) { // 1. Analysis the operation's operands and results // We need to analyze which operation results are needed by other @@ -947,16 +1231,15 @@ void CanonicalizerVectorOperation::run() { // Query groupResultYeildSet to map operaion result value to scf.yield // result value. - analysisGroupOperaionOperandsResults(fusionStrategy.getOpGroups(), - fusionStrategy.getOpGroupIndexMap()); - + analysisGroupOperaionOperandsResults(); + std::cout << "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!" + << " : " << fusionStrategy.getOpGroups().size() << std::endl; // Speical Operation Canonicalization - canonicalizeSpecialOperation(); + // canonicalizeSpecialOperation(); // 2.Generate vectorized IR for each operation group for (auto [idx, grp] : llvm::enumerate(fusionStrategy.getOpGroups())) { - generateGroupOpVectorizedIR(idx, grp, - fusionStrategy.getOpGroupIndexMap()); + generateGroupOpVectorizedIR(idx); } // 3. Some IR cleanup work @@ -981,6 +1264,15 @@ void CanonicalizerVectorOperation::run() { LDBG("Operation is not in vector dialect" << *op << "\n"); return false; } + + if (mlir::isa(op) || + mlir::isa(op)) { + if (!isReadWriteOnLastDim(op)) { + LDBG("Operation is not last dim read/write" << *op << "\n"); + return false; + } + } + return true; } @@ -1012,65 +1304,6 @@ void checkAndSetOperand( } } -// TODO: need to rewrite reduce operation as a performance forms like -// graph-compiler v1 -scf::ForOp constructReductionNestedForOp( - OpBuilder &b, const Location &loc, const ValueRange &iterArgs, - const VectorType &type, const llvm::ArrayRef &dims, size_t idx, - std::queue &queue, const llvm::SetVector &resultSet, - llvm::SmallVector &inductionVars, - const llvm::DenseMap &operandIdxMap, - const llvm::SmallVector &rdDims, - const llvm::DenseMap &opPermuationMap) { - const int loop_step = getDataTypeMAXSIMDLength(type); - - // loop initialization variable - auto zero = - b.create(b.getUnknownLoc(), b.getIndexType(), - b.getIntegerAttr(b.getIndexType(), 0)); - auto forSteps = b.create( - b.getUnknownLoc(), b.getIndexType(), - b.getIntegerAttr(b.getIndexType(), - idx == dims.size() - 1 ? loop_step : 1)); - auto numIter = b.create( - b.getUnknownLoc(), b.getIndexType(), - b.getIntegerAttr(b.getIndexType(), dims[idx])); - - // Create a loop and move vectorized operation into loops. 
- auto forOp = b.create( - b.getUnknownLoc(), zero, numIter, forSteps, iterArgs, - [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) { - inductionVars.emplace_back(iv); - - // inner most body of the loop - if (idx == dims.size() - 1) { - Operation *lastOperation = queue.front(); - while (!queue.empty()) { - auto x = queue.front(); - queue.pop(); - if (lastOperation == x) { - x->moveBefore(b.getBlock(), b.getBlock()->begin()); - } else { - x->moveAfter(lastOperation); - lastOperation = x; - } - // check operation type to set correct operand - checkAndSetOperand(x, loopState, operandIdxMap, inductionVars, - opPermuationMap); - } - maybeYieldValue(b, loc, resultSet.getArrayRef()); - } else { - - // outter loop - auto nxtFor = constructReductionNestedForOp( - b, loc, loopState, type, dims, idx + 1, queue, resultSet, - inductionVars, operandIdxMap, rdDims, opPermuationMap); - maybeYieldValue(b, loc, nxtFor->getResults()); - } - }); - return forOp; -} - scf::ForOp constructNestedForOp( OpBuilder &b, const Location &loc, const ValueRange &iterArgs, const VectorType &type, const llvm::ArrayRef &dims, size_t idx, @@ -1209,20 +1442,31 @@ void getOperationDataAxis(Operation *op, llvm::SmallVector &dataAxis) { auto dstType = shapeCastOp.getResultVectorType(); auto srcShape = srcType.getShape(); auto dstShape = dstType.getShape(); - shapeCastSourceAxis(srcShape, dstShape, dataAxis); + if (srcShape.size() < dstShape.size()) { + shapeCastSourceAxis(srcShape, dstShape, dataAxis); + } else { + shapeCastSourceAxis(dstShape, srcShape, dataAxis); + } }) .Default([&](Operation *op) { // default is last axis dataAxis.emplace_back( - op->getResultTypes().front().cast().getRank() - 1); + mlir::dyn_cast(op->getResultTypes().front()).getRank() - + 1); }); } bool hasDataDependency(Operation *op1, Operation *op2) { + if (!isSpecialOp(op1) and !isSpecialOp(op2)) { + return false; + } // op1 must be special operation if (!isSpecialOp(op1)) { return hasDataDependency(op2, op1); } + if (isSpecialOp(op1)) { + return true; + } auto hasSameAxis = [](const llvm::SmallVector &dims1, const llvm::SmallVector &dims2) { llvm::DenseSet checkSet(dims2.begin(), dims2.end()); @@ -1254,11 +1498,16 @@ bool hasDataDependency(Operation *op1, Operation *op2) { getOperationDataAxis(op2, dims2); llvm::DenseSet checkSet(dims2.begin(), dims2.end()); - for (auto x : dims2) { - if (!checkSet.contains(x)) { - return true; + if (!isSpecialOp(op2)) { + for (auto x : reductionDims) { + if (!checkSet.contains(x)) { + return true; + } } + } else { + // TODO: reduce operation fused with other special operation } + return false; }) .Default([&](Operation *op) { return false; }); @@ -1275,24 +1524,16 @@ bool isNeedNewGroup(llvm::SmallVector, 8> &opGroups, if (prevOp->getParentOp() != op->getParentOp()) { return true; } - // previous operation is a special operation - if (isSpecialOp(prevOp)) { - // special operation need to check data dependency axis - if (hasDataDependency(prevOp, op)) { - return true; - } + // special operation need to check data dependency axis + if (hasDataDependency(prevOp, op)) { return true; } + // previous operation vector type is not compatible with current operation if (!isCompatibleVectorType(prevOp, op)) { return true; } } - - // 2. 
check current operation - if (isSpecialOp(op)) { - return true; - } return false; } @@ -1307,12 +1548,12 @@ void addOperationToGroup( opGroupIndexMap[op] = opGroups.size() - 1; } +bool isInitOperation(Operation *op) { return mlir::isa(op); } + // We classify the operations we are interested in after filtering. Operations // of in the same group have no data dependencies. Those operations can generate // a same outter for loop. -void VectorFusionStrategy::classifyOperations( - func::FuncOp func, llvm::SmallVector, 8> &opGroups, - llvm::DenseMap &opGroupIndexMap) { +void VectorFusionStrategy::classifyOperations() { if (opGroups.empty()) { // dummpy opGroups.emplace_back(std::queue()); @@ -1321,6 +1562,14 @@ void VectorFusionStrategy::classifyOperations( TypeSwitch(op).Default([&](Operation *op) { if (filterOperation(op)) { addOperationToGroup(opGroups, opGroupIndexMap, op); + // update init operation + } + while (ignoreInitOperations.size() < opGroups.size()) { + ignoreInitOperations.emplace_back(std::queue()); + } + // some init operations need to ignore + if (isInitOperation(op)) { + ignoreInitOperations.back().push(op); } }); }); @@ -1354,7 +1603,7 @@ void setOperationOperandResult( Operation *op, const VectorType &newOperandType, const llvm::DenseMap &opMap) { for (auto [idx, x] : llvm::enumerate(op->getOperands())) { - if (x.getType().dyn_cast()) { + if (mlir::dyn_cast(x.getType())) { if (!opMap.contains(x.getDefiningOp())) { auto result = setOutGroupOperationOperandResult(x.getDefiningOp(), newOperandType); @@ -1365,46 +1614,37 @@ void setOperationOperandResult( } } for (auto x : op->getResults()) { - if (x.getType().dyn_cast()) { + if (mlir::dyn_cast(x.getType())) { x.setType(newOperandType); } } }; -Operation *createTransferWriteOpAfter(Operation *op, const Value &dest) { - auto rtType = op->getResultTypes()[0].dyn_cast(); - auto rank = rtType.getRank(); - auto dstType = dest.getType().dyn_cast(); - IRRewriter reWriter(op); - - auto zero = - reWriter.create(reWriter.getUnknownLoc(), 0); - - reWriter.setInsertionPointAfter(op); - SmallVector inBoundsVal(rank, true); - - SmallVector shapes; - SmallVector dynDims; - for (unsigned i = 0; i < rtType.getRank(); i++) { - shapes.push_back(rtType.getDimSize(i)); - if (rtType.isDynamicDim(i)) { - dynDims.push_back(reWriter.create(reWriter.getUnknownLoc(), - op->getResult(0), i)); - } - } - return reWriter.create( - reWriter.getUnknownLoc(), - /*vector=*/op->getResult(0), - /*source=*/dest, - /*indices=*/SmallVector(dstType.getRank(), zero), - /*inBounds=*/inBoundsVal); +void createNewConstantOp( + Operation *srcOp, vector::TransferWriteOp *transferWriteOp, + llvm::DenseMap &opPermuationMap) { + IRRewriter srcWriter(srcOp); + auto newOperandType = getVectorzedType(mlir::cast(srcOp)); + auto srcConstantOp = mlir::dyn_cast(srcOp); + Operation *newConstantOp = srcWriter.create( + srcOp->getLoc(), srcConstantOp.getValueAttr()); + newConstantOp->getResult(0).setType(newOperandType); + transferWriteOp->setOperand(0, newConstantOp->getResult(0)); + opPermuationMap.insert( + {mlir::cast(srcOp), transferWriteOp->getPermutationMap()}); + setOpVectorizationPermutationMap( + mlir::cast(srcOp), srcWriter, + mlir::dyn_cast( + transferWriteOp->getResults()[0].getType()), + transferWriteOp->getPermutationMap()); } /// Rewrite the operations in the group to vectorized form. 
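Before any loop is generated, every value inside a group is narrowed from its full tile type to the physical register width. A minimal illustration of that type computation, assuming the `getDataTypeMAXSIMDLength` helper defined earlier in this pass; the helper name below is hypothetical, and the actual in-place rewrite is performed by the function that follows.

    // Illustration only: the 1-D vector type used for one loop step. With
    // AVX-512 and f32 the lane count is 512 / 32 = 16, so a vector<16x64xf32>
    // value is processed as vector<16xf32> slices inside the generated loops.
    static VectorType getPerStepVectorType(Operation *op) {
      auto fullType = mlir::dyn_cast<VectorType>(op->getResultTypes().front());
      assert(fullType && "expected a vector-typed result");
      // Falls back to a step of 1 when the last dimension is narrower than
      // the register width.
      int64_t lanes = getDataTypeMAXSIMDLength(fullType);
      return VectorType::get({lanes}, fullType.getElementType());
    }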
-void rewriteOperationAsVectorize( - const std::queue &groupOps, - const llvm::DenseMap &opMap, IRRewriter &rewriter, - llvm::DenseMap &opPermuationMap) { +void CanonicalizerVectorOperation::rewriteOperationAsVectorize( + OpBuilder &rewriter, size_t groupId) { + auto &groupOps = commonUsedData.getFusionStrategy().getOpGroups()[groupId]; + auto &opMap = commonUsedData.getFusionStrategy().getOpGroupIndexMap(); + auto &opPermuationMap = commonUsedData.getOpPermuationMap(); std::queue transformQueue(groupOps); while (!transformQueue.empty()) { @@ -1416,18 +1656,20 @@ void rewriteOperationAsVectorize( [&](vector::TransferWriteOp transferWriteOp) { IRRewriter rewriter(transferWriteOp); auto newOperandType = getVectorzedType(transferWriteOp); + auto srcOp = transferWriteOp->getOperand(0).getDefiningOp(); + if (mlir::isa(srcOp)) { + createNewConstantOp(srcOp, &transferWriteOp, + opPermuationMap); + } else if (!isSpecialOp(srcOp)) { - if (!isSpecialOp( - transferWriteOp->getOperand(0).getDefiningOp())) { + transferWriteOp->getOperand(0).setType(newOperandType); opPermuationMap.insert( {transferWriteOp, transferWriteOp.getPermutationMap()}); - transferWriteOp->getOperand(0).setType(newOperandType); setOpVectorizationPermutationMap( transferWriteOp, rewriter, - transferWriteOp->getResult(0) - .getType() - .dyn_cast(), + mlir::dyn_cast( + transferWriteOp->getResult(0).getType()), transferWriteOp.getPermutationMap()); } @@ -1450,9 +1692,8 @@ void rewriteOperationAsVectorize( transferReadOp->getResult(0).setType(newOperandType); setOpVectorizationPermutationMap( transferReadOp, rewriter, - transferReadOp.getSource() - .getType() - .dyn_cast(), + mlir::dyn_cast( + transferReadOp.getSource().getType()), transferReadOp.getPermutationMap()); } @@ -1460,8 +1701,15 @@ void rewriteOperationAsVectorize( }) .Case( [&](vector::MultiDimReductionOp multiReductionOp) { - return success(); + llvm::llvm_unreachable_internal( + "It should not appear this operation."); + return failure(); }) + .Case([&](arith::ExtFOp extFop) { + auto newOperandType = getVectorzedType(extFop); + extFop->getResult(0).setType(newOperandType); + return success(); + }) .Default([&](Operation *op) { if (isSpecialOp(op)) { return success(); @@ -1476,48 +1724,7 @@ void rewriteOperationAsVectorize( } } -// analysis operation' operands are coming from which operation's result -// void analysisOperaionOperandSource( -// size_t idx, std::queue &grp, -// llvm::DenseMap &opGroupIndexMap, -// llvm::SmallVector, 8> &groupOperandNeedSet) { -// auto tmpOpQueue(grp); -// llvm::SetVector opOperands; -// while (!tmpOpQueue.empty()) { -// auto t = tmpOpQueue.front(); -// for (auto x : t->getOperands()) { -// // not in the same group -// if (opGroupIndexMap.contains(x.getDefiningOp()) && -// opGroupIndexMap[x.getDefiningOp()] != idx) { -// groupOperandNeedSet[idx].insert(x); -// } else { -// groupOperandNeedSet[idx].insert(x); -// } -// } -// tmpOpQueue.pop(); -// } -// } - -// canonicalizing operation as tensor empty and transfer write the operation -// result into the empty tensor -[[nodiscard]] std::pair -canonicalizeSourceOperation(Operation *op) { - auto emtpyOp = createTensorEmptyBefore(op); - auto writeOp = createTransferWriteOpAfter(op, emtpyOp->getResults()[0]); - return std::make_pair(emtpyOp->getResults()[0], writeOp->getResults()[0]); -} - -[[nodiscard]] Value -canonicalizeCurrentOperation(Operation *op, const Value &transferReadOperand, - size_t operandIdx, - vector::TransferReadOp *srcReadOp = nullptr) { - // transfer_read operation - auto 
readOp = createTransferReadOpBefore(op, transferReadOperand, srcReadOp); - op->setOperand(operandIdx, readOp->getResults()[0]); - return readOp->getResults()[0]; -} - -mlir::FailureOr getOperationDestnationOperand(Operation *op) { +mlir::FailureOr getOperationOperateTensor(Operation *op) { return llvm::TypeSwitch>(op) .Case( [&](vector::TransferWriteOp transferWriteOp) { @@ -1534,29 +1741,6 @@ mlir::FailureOr getOperationDestnationOperand(Operation *op) { }); } -// analysis operations of current group need which operation's result value -// void analysisGroupOperationOperands( -// llvm::SmallVector, 8> &opGroups, -// llvm::DenseMap &opGroupIndexMap, -// llvm::SmallVector, 8> &groupOperandNeedSet) { - -// for (auto [idx, grp] : enumerate(opGroups)) { -// analysisOperaionOperandSource(idx, grp, opGroupIndexMap, -// groupOperandNeedSet); -// } -// } - -// TODO: need to rewrite reduce -// llvm::SmallVector & -// getreductionAxis(vector::MultiDimReductionOp &reductionOp, -// llvm::SmallVector &rdDims) { -// auto rdDimsAttr = reductionOp.getreductionAxis().getValue(); -// for (auto x : rdDimsAttr) { -// rdDims.emplace_back(x.cast().getInt()); -// } -// return rdDims; -// } - void updateOpOperandResultInGroups( llvm::SmallVector, 8> &opGroups, llvm::DenseMap &opGroupIndexMap, size_t opGid, @@ -1585,33 +1769,44 @@ void updateOpOperandResultInGroups( opGroups[opGid] = newOpQueue; } +void VectorFusionStrategy::run() { classifyOperations(); } + // analysis operation result of current group whether needed by other // operation which out of current group -void CanonicalizerVectorOperation::analysisGroupOperationResults( - llvm::SmallVector, 8> &opGroups, - llvm::DenseMap &opGroupIndexMap) { +void CanonicalizerVectorOperation::analysisGroupOperationResults() { llvm::DenseMap> srcOpCanoniclizedMap; - IRRewriter rewriter(func); + auto &commonUsedData = getCommonUsedData(); + auto &opGroups = commonUsedData.getFusionStrategy().getOpGroups(); + auto &opGroupIndexMap = + commonUsedData.getFusionStrategy().getOpGroupIndexMap(); + auto &groupOpIterArgs = commonUsedData.getGroupOpIterArgs(); + auto &groupOpResults = commonUsedData.getGroupOpResults(); func.walk([&](Operation *op) { for (auto [idx, opd] : llvm::enumerate(op->getOperands())) { auto sourceOp = opd.getDefiningOp(); + if (opGroupIndexMap.contains(sourceOp)) { auto sourceOpGid = opGroupIndexMap[sourceOp]; - // bool notInSameGroup = opGroupIndexMap.contains(op) && sourceOpGid != opGroupIndexMap[op]; bool outOfGroup = !opGroupIndexMap.contains(op); + if (notInSameGroup or outOfGroup) { // update init iterargs - auto dstRet = getOperationDestnationOperand(sourceOp); + auto dstRet = getOperationOperateTensor(sourceOp); + // need to generate tensor.emtpy and vector.transfer_write, write + // operand to tensor and read operand from the tensor, generate + // vector.transfer_read if (failed(dstRet)) { + // already generate result tensor if (!srcOpCanoniclizedMap.contains(sourceOp)) { - auto [init, result] = canonicalizeSourceOperation(sourceOp); - srcOpCanoniclizedMap.insert({sourceOp, {init, result}}); + auto [resultTensor, result] = + canonicalizeSourceOperation(sourceOp); + srcOpCanoniclizedMap.insert({sourceOp, {resultTensor, result}}); updateOpOperandResultInGroups(opGroups, opGroupIndexMap, - sourceOpGid, sourceOp, init, + sourceOpGid, sourceOp, resultTensor, result); - groupOpIterArgs[sourceOpGid].insert(init); + groupOpIterArgs[sourceOpGid].insert(resultTensor); groupOpResults[sourceOpGid].insert(result); } @@ -1621,6 +1816,8 @@ void 
CanonicalizerVectorOperation::analysisGroupOperationResults( opGroupIndexMap[op], op, opInit); } else { + // if source operation is transfer_read, we need to generate a same + // transfer_read operation like source operation. if (mlir::isa(sourceOp)) { auto transferReadOp = mlir::dyn_cast(sourceOp); @@ -1648,23 +1845,13 @@ void CanonicalizerVectorOperation::analysisGroupOperationResults( LDBG("Complete analysis group operation results\n"); } -void CanonicalizerVectorOperation::analysisGroupOperaionOperandsResults( - llvm::SmallVector, 8> &opGroups, - llvm::DenseMap &opGroupIndexMap) { - // prepare - if (opGroups.size() != groupOpResults.size()) { - for (size_t i = 0; i < opGroups.size(); i++) { - groupOpResults.emplace_back(llvm::SetVector()); - groupOpIterArgs.emplace_back(llvm::SetVector()); - } - LDBG("Size of groupOpResults is : " << groupOpResults.size()); - } +void CanonicalizerVectorOperation::analysisGroupOperaionOperandsResults() { // Operands // analysisGroupOperationOperands(opGroups, opGroupIndexMap); // Results - analysisGroupOperationResults(opGroups, opGroupIndexMap); + analysisGroupOperationResults(); } mlir::FailureOr generateVectorizedForLoop( @@ -1728,8 +1915,7 @@ bool hasSpecialOperation(std::queue &grp) { std::queue tmpQ(grp); while (!tmpQ.empty()) { auto curOp = tmpQ.front(); - if (mlir::isa(curOp) or - mlir::isa(curOp)) { + if (isSpecialOp(curOp)) { return true; } tmpQ.pop(); @@ -1737,9 +1923,8 @@ bool hasSpecialOperation(std::queue &grp) { return false; } -void CanonicalizerVectorOperation::generateGroupOpVectorizedIR( - const int idx, std::queue &grp, - llvm::DenseMap &opGroupIndexMap) { +void CanonicalizerVectorOperation::generateGroupOpVectorizedIR(const int idx) { + auto &grp = commonUsedData.getFusionStrategy().getOpGroups()[idx]; if (grp.empty()) { LDBG("Current operation Group is empty."); return; @@ -1748,6 +1933,9 @@ void CanonicalizerVectorOperation::generateGroupOpVectorizedIR( if (hasSpecialOperation(grp)) { return; } + auto &groupOpResults = commonUsedData.getGroupOpResults(); + auto &opPermuationMap = commonUsedData.getOpPermuationMap(); + auto &groupOpIterArgs = commonUsedData.getGroupOpIterArgs(); auto getType = getOperationVectorType(grp.front()); if (failed(getType)) { LDBG("Failed to get vector type for operation: " << *grp.front() << "\n"); @@ -1757,8 +1945,16 @@ void CanonicalizerVectorOperation::generateGroupOpVectorizedIR( IRRewriter rewriter(grp.back()); rewriter.setInsertionPointAfter(grp.back()); // 1. Rewrite operation as vectorized form - rewriteOperationAsVectorize(grp, opGroupIndexMap, rewriter, opPermuationMap); + rewriteOperationAsVectorize(rewriter, idx); // 2. Generate loop + // 2.a more init operation before current group operations + // auto firstGrpOp = grp.front(); + // while (!fusionStrategy.getIgnoreInitOperations()[idx].empty()) { + // auto initOp = fusionStrategy.getIgnoreInitOperations()[idx].front(); + // initOp->moveBefore(firstGrpOp); + // fusionStrategy.getIgnoreInitOperations()[idx].pop(); + // } + // 2.b generate common outter for loop auto forOp = generateVectorizedForLoop(rewriter, groupOpResults[idx], groupOpIterArgs[idx], opShapes, grp, opPermuationMap); @@ -1781,7 +1977,8 @@ struct CPUPhysicalRegisterPass RewritePatternSet patterns(ctx); auto func = getOperation(); // canonicalize vector operation, default use vector-based fusion strategy. 
- CanonicalizerVectorOperation canonicalizer(func); + CanonicalizerVectorOperation canonicalizer( + func, CanonicalizerKind::OperationsGroup); canonicalizer.run(); } }; diff --git a/lib/gc/Transforms/LowerTileVectorPass.cpp b/lib/gc/Transforms/LowerTileVectorPass.cpp index f60e94e75..42dd19bd0 100644 --- a/lib/gc/Transforms/LowerTileVectorPass.cpp +++ b/lib/gc/Transforms/LowerTileVectorPass.cpp @@ -9,19 +9,27 @@ #include "gc/Transforms/Passes.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/Transforms/Transforms.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/CSE.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" +#include namespace mlir { namespace gc { @@ -48,48 +56,83 @@ bool is_innermost_ir(Operation *op) { /// Need to check if the reassociation are static/constant. LogicalResult lowerExpandOpPrecondition(tensor::ExpandShapeOp expandOp) { + // + auto outputShape = expandOp.getStaticOutputShape(); + if (llvm::any_of(outputShape, + [](int64_t x) { return x == ShapedType::kDynamic; })) { + LDBG("Output shape must be static: " << expandOp << "\n"); + return failure(); + } + + return success(); +} + +LogicalResult lowerBitcastOpPrecondition(tensor::BitcastOp bitCastOp) { + if (bitCastOp.getSource().getType().getNumDynamicDims()) { + LDBG("Type must be static: " << bitCastOp << "\n"); + return failure(); + } + return success(); +} + +/// Need to check if the reassociation are static/constant. 
+LogicalResult +lowerCollapseShapeOpPrecondition(tensor::CollapseShapeOp expandOp) { - // if (llvm::any_of(expandOp.getReassociation(), [](ArrayAttr res) { - // if (llvm::any_of(res, [](Attribute x) { - // return !getConstantIntValue(x).has_value(); - // })) { - // return false; - // } - // return true; - // })) { - // LDBG("Reassociation must be constant: " << expandOp << "\n"); - // return failure(); - // } + if (llvm::any_of(expandOp.getReassociation(), [](Attribute x) { + return !getConstantIntValue(x).has_value(); + })) { + LDBG("Reassociation must be constant: " << expandOp << "\n"); + return failure(); + } + return success(); +} + +LogicalResult lowerConcatOpPrecondition(tensor::ConcatOp concatOp) { + for (auto x : concatOp->getOperands()) { + auto tensorType = mlir::dyn_cast(x.getType()); + if (!tensorType) { + LDBG("Operation type error: " << concatOp << "\n"); + return failure(); + } + if (tensorType.getNumDynamicDims()) { + LDBG("Type must be static: " << concatOp << "\n"); + return failure(); + } + } return success(); } -LogicalResult lowerTargetOpPrecondition(Operation *op, - ArrayRef inputVectorSizes, - ArrayRef inputScalableVecDims, - bool vectorizeNDExtract, - bool flatten1DDepthwiseConv) { +LogicalResult lowerTargetOpPrecondition(Operation *op) { return TypeSwitch(op) .Case([&](auto expandShapeOp) { return lowerExpandOpPrecondition(expandShapeOp); }) - .Case( - [&](auto collapseShapeOp) { return success(); }) - .Case([&](auto collapseShapeOp) { return success(); }) - .Case([&](auto concatOp) { return success(); }) + .Case([&](auto collapseShapeOp) { + return lowerCollapseShapeOpPrecondition(collapseShapeOp); + }) + .Case( + [&](auto bitCastOp) { return lowerBitcastOpPrecondition(bitCastOp); }) + .Case( + [&](auto concatOp) { return lowerConcatOpPrecondition(concatOp); }) .Default([](auto) { return failure(); }); } /// Create a TransferReadOp from `source` with static shape `readShape`. Value createTransferRead(OpBuilder &builder, Location loc, Value source, - ArrayRef readShape, Value padValue) { + ArrayRef readShape) { assert(llvm::none_of(readShape, [](int64_t s) { return s == ShapedType::kDynamic; })); assert(source && " source null."); - auto sourceShape = dyn_cast(source.getType()).getShape(); + auto shapedType = mlir::dyn_cast(source.getType()); + auto sourceShape = shapedType.getShape(); + auto vectorType = VectorType::get(readShape, shapedType.getElementType()); + + auto padValue = builder.create( + loc, builder.getZeroAttr(shapedType.getElementType())); assert(sourceShape.size() == readShape.size()); - auto vectorType = VectorType::get(readShape, padValue.getType()); int64_t readRank = readShape.size(); auto zero = builder.create(loc, 0); SmallVector inBoundsVal(readRank, true); @@ -108,7 +151,6 @@ Value createTransferRead(OpBuilder &builder, Location loc, Value source, } } -/// Given an input, the mixed destSizes, and the vector sizes for vectorization, /// create an empty destination tensor and create a TransferWriteOp from the /// input to the empty tensor. 
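Taken together, the read and write helpers in this file reduce a static reshape to a read / shape-cast / write chain. A hedged sketch of that chain for a `tensor.expand_shape` from `tensor<8x16xf32>` to `tensor<8x4x4xf32>`; the wrapper name, the literal shapes, and the parameter comments are illustrative only, and the real driver is `lowerTensorExpandShapeOp` below.

    // Sketch only: full-shape transfer_read, vector.shape_cast to the expanded
    // rank, then transfer_write into a fresh destination tensor.
    static Value lowerExpandShapeExample(OpBuilder &builder, Location loc,
                                         Value src /*tensor<8x16xf32>*/) {
      Value read = createTransferRead(builder, loc, src, /*readShape=*/{8, 16});
      auto dstVecTy = VectorType::get({8, 4, 4}, builder.getF32Type());
      Value reshaped = builder.create<vector::ShapeCastOp>(loc, dstVecTy, read);
      SmallVector<OpFoldResult> destSizes{builder.getIndexAttr(8),
                                          builder.getIndexAttr(4),
                                          builder.getIndexAttr(4)};
      Operation *write = createTransferWrite(builder, loc, reshaped, destSizes,
                                             /*inputVectorSizes=*/{8, 4, 4});
      return write->getResult(0);
    }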
Operation *createTransferWrite(OpBuilder &builder, Location loc, Value input, @@ -129,7 +171,7 @@ Operation *createTransferWrite(OpBuilder &builder, Location loc, Value input, assert(llvm::none_of( destShape.drop_front(inputVectorSizes.size()), [](int64_t size) { return size == ShapedType::kDynamic; }) && - "Only dims aligned with inputVectorSizes may be dynamic"); + "InputVectorSizes may be dynamic"); return write; } @@ -140,20 +182,16 @@ Operation *createTransferWrite(OpBuilder &builder, Location loc, Value input, /// tensor template LogicalResult lowerTensorExpandShapeOp(RewriterBase &rewriter, T expandShapeOp, - ArrayRef inputVectorSizes, SmallVectorImpl &newResults) { OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(expandShapeOp); RankedTensorType expandShapeTensorType = expandShapeOp.getSrcType(); - SmallVector readMaskShape(inputVectorSizes.begin(), - inputVectorSizes.end()); + SmallVector readMaskShape; ArrayRef sourceShape = expandShapeTensorType.getShape(); ArrayRef resultShape = expandShapeOp.getResultType().getShape(); - - readMaskShape.append(sourceShape.begin() + inputVectorSizes.size(), - sourceShape.end()); + readMaskShape.append(sourceShape.begin(), sourceShape.end()); ReifiedRankedShapedTypeDims reifiedRetShapes; LogicalResult status = @@ -165,14 +203,11 @@ LogicalResult lowerTensorExpandShapeOp(RewriterBase &rewriter, T expandShapeOp, } Location loc = expandShapeOp->getLoc(); - auto padValue = rewriter.create( - loc, rewriter.getZeroAttr(expandShapeTensorType.getElementType())); - // Read result, mask if necessary. If transferReadOp shape is not equal // to shape of source, then a mask is necessary. Value readResult = createTransferRead( rewriter, loc, expandShapeOp.getSrc(), - ArrayRef(readMaskShape.begin(), readMaskShape.end()), padValue); + ArrayRef(readMaskShape.begin(), readMaskShape.end())); auto resultVectorType = VectorType::get(resultShape, expandShapeTensorType.getElementType()); @@ -180,9 +215,7 @@ LogicalResult lowerTensorExpandShapeOp(RewriterBase &rewriter, T expandShapeOp, rewriter.create(loc, resultVectorType, readResult); SmallVector writeMaskShape( - expandShapeOp.getResultType().hasStaticShape() - ? inputVectorSizes - : shapeCastOp.getResultVectorType().getShape()); + shapeCastOp.getResultVectorType().getShape()); Operation *write = createTransferWrite(rewriter, loc, shapeCastOp.getResult(), reifiedRetShapes[0], writeMaskShape); newResults.push_back(write->getResult(0)); @@ -196,7 +229,6 @@ LogicalResult lowerTensorExpandShapeOp(RewriterBase &rewriter, T expandShapeOp, /// tensor LogicalResult lowerTensorBitcastOp(RewriterBase &rewriter, tensor::BitcastOp bitCastOp, - ArrayRef inputVectorSizes, SmallVectorImpl &newResults) { OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(bitCastOp); @@ -206,35 +238,29 @@ LogicalResult lowerTensorBitcastOp(RewriterBase &rewriter, auto resultType = bitCastOp.getResult().getType(); auto resultShape = resultType.getShape(); - SmallVector readMaskShape(inputVectorSizes.begin(), - inputVectorSizes.end()); - - readMaskShape.append(sourceShape.begin() + inputVectorSizes.size(), - sourceShape.end()); - + SmallVector readMaskShape; + readMaskShape.append(sourceShape.begin(), sourceShape.end()); Location loc = bitCastOp->getLoc(); - auto padValue = rewriter.create( - loc, rewriter.getZeroAttr(sourceType.getElementType())); - // Read result, mask if necessary. If transferReadOp shape is not equal // to shape of source, then a mask is necessary. 
Value readResult = createTransferRead( rewriter, loc, bitCastOp->getOperand(0), - ArrayRef(readMaskShape.begin(), readMaskShape.end()), padValue); + ArrayRef(readMaskShape.begin(), readMaskShape.end())); auto resultVectorType = VectorType::get(resultShape, resultType.getElementType()); vector::BitCastOp vectorbitCastOp = rewriter.create(loc, resultVectorType, readResult); - Value zero = rewriter.create(loc, 0); - SmallVector indices(resultType.getRank(), zero); - Value dest = rewriter.create(loc, resultShape, - resultType.getElementType()); - Operation *write = rewriter.create( - loc, vectorbitCastOp, dest, indices, - rewriter.getMultiDimIdentityMap(resultType.getRank())); + SmallVector writeMaskShape( + vectorbitCastOp.getResultVectorType().getShape()); + llvm::SmallVector destSizes; + for (auto size : resultShape) + destSizes.emplace_back(rewriter.getIndexAttr(size)); + auto write = + createTransferWrite(rewriter, loc, vectorbitCastOp->getResults()[0], + destSizes, writeMaskShape); newResults.push_back(write->getResults()[0]); return success(); } @@ -247,7 +273,6 @@ LogicalResult lowerTensorBitcastOp(RewriterBase &rewriter, /// tensor. LogicalResult lowerTensorConcatOp(RewriterBase &rewriter, tensor::ConcatOp concatOp, - ArrayRef inputVectorSizes, SmallVectorImpl &newResults) { OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(concatOp); @@ -263,14 +288,11 @@ LogicalResult lowerTensorConcatOp(RewriterBase &rewriter, return failure(); // Compute the partial sums for the slice offsets. - int64_t dim = concatOp.getDim(); Value dimValue = rewriter.create(loc, rewriter.getIndexAttr(dim)); int64_t rank = concatOp.getResultType().getRank(); - SmallVector strides(rank, rewriter.getIndexAttr(1)); - SmallVector offsets(rank, rewriter.getIndexAttr(0)); // Construct the chain of insert_slice ops into the destination. Value result = *dest; @@ -279,19 +301,14 @@ LogicalResult lowerTensorConcatOp(RewriterBase &rewriter, SmallVector sizes = tensor::getMixedSizes(rewriter, loc, input); - SmallVector readMaskShape(inputVectorSizes.begin(), - inputVectorSizes.end()); + SmallVector readMaskShape; auto inputType = llvm::cast(input.getType()); auto sourceShape = inputType.getShape(); - readMaskShape.append(sourceShape.begin() + inputVectorSizes.size(), - sourceShape.end()); - auto padValue = rewriter.create( - loc, rewriter.getZeroAttr(inputType.getElementType())); + readMaskShape.append(sourceShape.begin(), sourceShape.end()); Value readResult = createTransferRead( rewriter, loc, input, - ArrayRef(readMaskShape.begin(), readMaskShape.end()), - padValue); + ArrayRef(readMaskShape.begin(), readMaskShape.end())); Value zero = rewriter.create(loc, 0); SmallVector indices(rank, zero); indices[dim] = previous_offset; @@ -309,28 +326,11 @@ LogicalResult lowerTensorConcatOp(RewriterBase &rewriter, return success(); } -/// Emit a suitable vector form for an operation. If provided, -/// `inputVectorSizes` are used to vectorize this operation. -/// `inputVectorSizes` must match the rank of the iteration space of the -/// operation and the input vector sizes must be greater than or equal to -/// their counterpart iteration space sizes, if static. `inputVectorShapes` -/// also allows the vectorization of operations with dynamic shapes. 
-LogicalResult convert2TargetOperation(RewriterBase &rewriter, Operation *op, - ArrayRef inputVectorSizes, - ArrayRef inputScalableVecDims, - bool vectorizeNDExtract, - bool flatten1DDepthwiseConv) { +LogicalResult convert2TargetOperation(RewriterBase &rewriter, Operation *op) { LDBG("Attempting to vectorize:\n" << *op << "\n"); - LDBG("Input vector sizes: "); - LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs())); - LLVM_DEBUG(llvm::dbgs() << "\n"); - LDBG("Input scalable vector dims: "); - LLVM_DEBUG(llvm::interleaveComma(inputScalableVecDims, llvm::dbgs())); - LLVM_DEBUG(llvm::dbgs() << "\n"); - - if (failed(lowerTargetOpPrecondition(op, inputVectorSizes, - inputScalableVecDims, vectorizeNDExtract, - flatten1DDepthwiseConv))) { + + if (failed(lowerTargetOpPrecondition(op))) { + std::cout << "FAILED TO LOWER TARGET OP\n" << std::endl; LDBG("Vectorization pre-conditions failed\n"); return failure(); } @@ -339,20 +339,18 @@ LogicalResult convert2TargetOperation(RewriterBase &rewriter, Operation *op, auto lowerResult = TypeSwitch(op) .Case([&](auto expandShapeOp) { - return lowerTensorExpandShapeOp(rewriter, expandShapeOp, - inputVectorSizes, results); + return lowerTensorExpandShapeOp( + rewriter, expandShapeOp, results); }) .Case([&](auto collapseShapeOp) { - return lowerTensorExpandShapeOp(rewriter, collapseShapeOp, - inputVectorSizes, results); + return lowerTensorExpandShapeOp( + rewriter, collapseShapeOp, results); }) .Case([&](auto bitCastOp) { - return lowerTensorBitcastOp(rewriter, bitCastOp, inputVectorSizes, - results); + return lowerTensorBitcastOp(rewriter, bitCastOp, results); }) .Case([&](auto concatOp) { - return lowerTensorConcatOp(rewriter, concatOp, inputVectorSizes, - results); + return lowerTensorConcatOp(rewriter, concatOp, results); }) .Default([](auto) { return failure(); }); @@ -376,47 +374,35 @@ bool is_required_tensorOp(Operation *operation) { llvm::isa(operation); } -struct LinalgConvertTileVectorPass : public RewritePattern { - - explicit LinalgConvertTileVectorPass(MLIRContext *context, - bool vectorizeExtract = false, - bool flatten1DDepthwiseConv = false) - : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} - LogicalResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const override { - - auto linalgOp = llvm::dyn_cast(op); - if (!linalgOp || !is_innermost_ir(op)) - return rewriter.notifyMatchFailure(op, "Not expected operations."); - - return linalg::vectorize(rewriter, op, /*inputVectorSizes=*/{}, - /*scalableVecDims=*/{}, true, false); - } -}; +template +struct OperationConvertTileVectorPass : public RewritePattern { -struct TensorPackConvertVectorPass : public RewritePattern { + explicit OperationConvertTileVectorPass(MLIRContext *context, + bool vectorizeNDExtract = false, + bool flatten1DDepthwiseConv = false) + : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context), + vectorizeNDExtract(vectorizeNDExtract), + flatten1DDepthwiseConv(flatten1DDepthwiseConv) {} - explicit TensorPackConvertVectorPass(MLIRContext *context, - bool vectorizeExtract = false, - bool flatten1DDepthwiseConv = false) - : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { - tensor::PackOp tensorPackOp = dyn_cast(op); - if (!tensorPackOp || !is_innermost_ir(op)) + auto targetOp = llvm::dyn_cast(op); + if (!targetOp || !is_innermost_ir(op)) return rewriter.notifyMatchFailure(op, "Not expected operations."); return 
linalg::vectorize(rewriter, op, /*inputVectorSizes=*/{}, - /*scalableVecDims=*/{}, true, false); + /*scalableVecDims=*/{}, vectorizeNDExtract, + flatten1DDepthwiseConv); } + +private: + bool vectorizeNDExtract, flatten1DDepthwiseConv; }; struct TensorUnpackConvertVectorPass : public RewritePattern { - explicit TensorUnpackConvertVectorPass(MLIRContext *context, - bool vectorizeExtract = false, - bool flatten1DDepthwiseConv = false) + explicit TensorUnpackConvertVectorPass(MLIRContext *context) : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { @@ -433,10 +419,10 @@ struct TensorUnpackConvertVectorPass : public RewritePattern { llvm::ArrayRef inputShape = resultTy.getShape(); std::vector targetVectorSizes = inputShape.vec(); - llvm::SmallVector targetVecDims(inputShape.size(), false); + llvm::SmallVector targetVecDims(inputShape.size(), false); return linalg::vectorize(rewriter, op, /*inputVectorSizes=*/targetVectorSizes, - /*scalableVecDims=*/targetVecDims, true, false); + /*scalableVecDims=*/targetVecDims, false, false); } }; @@ -453,16 +439,16 @@ struct TensorOpConvertVectorPass : public RewritePattern { if (!is_target || !is_innermost_ir(op)) return rewriter.notifyMatchFailure(op, "Not expected operations."); - return convert2TargetOperation(rewriter, op, /*inputVectorSizes=*/{}, - /*scalableVecDims=*/{}, true, false); + return convert2TargetOperation(rewriter, op); } }; /// Pass that lower to tile vector. void populateLowerToTileVectorPatterns(RewritePatternSet &patterns) { - patterns.add(patterns.getContext()); + patterns.add, + OperationConvertTileVectorPass>( + patterns.getContext()); patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); patterns.add(patterns.getContext()); } @@ -472,23 +458,29 @@ struct LowerTileVectorPass // auto *ctx = &getContext(); RewritePatternSet patterns(ctx); - tensor::populateRewriteAsConstantPatterns(patterns); + + tensor::ControlFoldFn defaultControlFn = [](OpOperand *fusedOperand) { + Operation *producer = fusedOperand->get().getDefiningOp(); + return producer && producer->hasOneUse(); + }; + tensor::populateRewriteAsConstantPatterns(patterns, defaultControlFn); tensor::populateReassociativeReshapeFoldingPatterns(patterns); - populateLowerToTileVectorPatterns(patterns); - linalg::populatePadOpVectorizationPatterns(patterns); tensor::populateFoldTensorSubsetOpPatterns(patterns); tensor::populateFoldTensorEmptyPatterns(patterns, true); tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns); - // vector::VectorTransformsOptions vectorTransformOptions; - // vector::populateVectorMultiReductionLoweringPatterns( - // patterns, vectorTransformOptions.vectorMultiReductionLowering); - // vector::populateVectorShapeCastLoweringPatterns(patterns); - // vector::VectorTransformsOptions options; - // options.vectorTransposeLowering = - // vector::VectorTransposeLowering::Shuffle16x16; - // vector::populateVectorTransposeLoweringPatterns(patterns, options); + populateLowerToTileVectorPatterns(patterns); + linalg::populatePadOpVectorizationPatterns(patterns); + + // vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); + vector::populateSinkVectorBroadcastPatterns(patterns); + vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx); + vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + auto curOp = 
getOperation(); + IRRewriter reWriter(curOp); + DominanceInfo domInfo(curOp); + eliminateCommonSubExpressions(reWriter, domInfo, curOp); } }; } // namespace From 9e4364bde69ca0bb96fe9ecac968be9fa61853cb Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Thu, 27 Jun 2024 18:13:47 +0800 Subject: [PATCH 10/66] record --- include/gc/Transforms/TilingVector.h | 107 ++- lib/gc/Transforms/CPUPhysicalResigterPass.cpp | 790 +++++++++++------- lib/gc/Transforms/LowerTileVectorPass.cpp | 2 +- 3 files changed, 596 insertions(+), 303 deletions(-) diff --git a/include/gc/Transforms/TilingVector.h b/include/gc/Transforms/TilingVector.h index 0227a3794..a280999d2 100644 --- a/include/gc/Transforms/TilingVector.h +++ b/include/gc/Transforms/TilingVector.h @@ -41,11 +41,13 @@ #include "llvm/Support/Casting.h" #include "llvm/Support/ErrorHandling.h" #include +#include #include #include #include #include #include +#include #include #include namespace mlir { @@ -53,10 +55,6 @@ namespace gc { namespace { Value makeIndexArithConstantOp(OpBuilder &opBuilder, Location &loc, int64_t x); -void rewriteOperationAsVectorize( - const std::queue &groupOps, - const llvm::DenseMap &opMap, OpBuilder &rewriter, - llvm::DenseMap &opPermuationMap); void checkAndSetOperand( Operation *op, const ValueRange &iterArgs, const llvm::DenseMap &operandIdxMap, @@ -113,24 +111,40 @@ class VectorFusionStrategy { enum CanonicalizerKind { OperationsGroup, Operations }; -class MultiReductionCanonicalizer { +template class SpecialOperationCanonicalizer { +private: + llvm::SmallVector candidateRdOps; + +public: + SpecialOperationCanonicalizer() = default; + SpecialOperationCanonicalizer(const llvm::SmallVector &candidateRdOps) + : candidateRdOps(candidateRdOps) {} + llvm::SmallVector &getCandidateOps(); + virtual void prepareSpecialOperationInfo() = 0; +}; + +class MultiReductionCanonicalizer + : virtual public SpecialOperationCanonicalizer< + vector::MultiDimReductionOp> { private: - llvm::SmallVector candidateRdOps; llvm::SmallVector reductionAxis, parallelAxis; + std::queue prevOps, postOps, accRelatedOps, sourceRelatedOps; bool haslastDimReduction = false; bool isStandaloneOp = false; int64_t typeRank = -1; + llvm::SetVector originalOpResults; + VectorType sourceType, accType; + llvm::SmallDenseMap resultIdxMap; public: MultiReductionCanonicalizer( const llvm::SmallVector &candidateRdOps) - : candidateRdOps(candidateRdOps) { - assert(candidateRdOps.size() > 1); + : SpecialOperationCanonicalizer( + candidateRdOps) { isStandaloneOp = candidateRdOps.size() == 1; - prepareReductionInfo(); + prepareSpecialOperationInfo(); }; int64_t getTypeRank(); - llvm::SmallVector &getCandidateOps(); void getReductionAxisAndParallelAxis(); bool hasLastDimReduction(); bool getIsStandAloneOp() { return isStandaloneOp; } @@ -138,7 +152,38 @@ class MultiReductionCanonicalizer { void initParallelAxis(); llvm::SmallVector &getReductionAxis() { return reductionAxis; }; llvm::SmallVector &getParallelAxis() { return parallelAxis; }; - void prepareReductionInfo(); + std::queue &getPrevOps() { return prevOps; } + std::queue &getPostOps() { return postOps; } + std::queue &getAccRelatedOps() { return accRelatedOps; } + std::queue &getSourceRelatedOps() { return sourceRelatedOps; } + llvm::SetVector &getOriginalOpResults() { return originalOpResults; } + VectorType &getSourceType() { return sourceType; }; + VectorType &getAccType() { return accType; }; + llvm::SmallDenseMap &getResultIdxMap() { return resultIdxMap; } + void setResultIdxMap(const llvm::SmallDenseMap 
&map) { + resultIdxMap = std::move(map); + } + void prepareSpecialOperationInfo() override; +}; + +class BroadcastCanonicalizer + : virtual public SpecialOperationCanonicalizer { +private: +public: + BroadcastCanonicalizer( + const llvm::SmallVector &candidateBcOps) + : SpecialOperationCanonicalizer(candidateBcOps){}; + void prepareSpecialOperationInfo() override {} +}; + +class TransposeCanonicalizer + : virtual public SpecialOperationCanonicalizer { +private: +public: + TransposeCanonicalizer( + const llvm::SmallVector &candidateTpOps) + : SpecialOperationCanonicalizer(candidateTpOps){}; + void prepareSpecialOperationInfo() override {} }; class CanonicalizerCommonUsedData { @@ -150,7 +195,9 @@ class CanonicalizerCommonUsedData { // store read and write operations permutation maps in order to convenient // to replace loop induction var llvm::DenseMap opPermuationMap; - llvm::SmallVector multiRdCanonicalizer; + llvm::SmallVector multiRdCanonicalizers; + llvm::SmallVector broadcastCanonicalizers; + llvm::SmallVector transposeCanonicalizers; public: CanonicalizerCommonUsedData() = default; @@ -205,9 +252,21 @@ class CanonicalizerCommonUsedData { llvm::DenseMap &getOpPermuationMap() { return opPermuationMap; } + llvm::SmallVector &getMultiRdCanonicalizer() { - return multiRdCanonicalizer; + return multiRdCanonicalizers; + } + + llvm::SmallVector &getBroadcastCanonicalizer() { + return broadcastCanonicalizers; + } + + llvm::SmallVector &getTransposeCanonicalizer() { + return transposeCanonicalizers; } + + // other methods + void initSpeicalOperationCanonicalizers(); }; class CanonicalizerVectorOperation { @@ -239,25 +298,35 @@ class CanonicalizerVectorOperation { void analysisGroupOperaionOperandsResults(); + void generateEmptyTensorAndWrite( + Operation *sourceOp, llvm::DenseMap> + &srcOpCanoniclizedMap); void analysisGroupOperationResults(); LogicalResult canonicalizeReductionOperation(); LogicalResult canonicalizeTransposeOperation(vector::TransposeOp &transposeOp, IRRewriter &rewriter); - void rewriteOperationAsVectorize(OpBuilder &rewriter, size_t groupId); + void rewriteOperationAsVectorize(OpBuilder &rewriter, size_t groupId, + const std::queue &queue = {}); // special operation methods scf::ForOp generateMultiReductionForLoop(const size_t grpIdx); void getCandidateSpecialOps(); void canonicalizeSpecialOperation(); - scf::ForOp parallelAxisGenerateForLoop( - const int groupIdx, const int parallelIdx, ValueRange &initArgs, - llvm::SmallVector &inductionVars, Value &originalWriteResult); + scf::ForOp - reductionAxisGenerateForLoop(const int groupIdx, const size_t reductionIdx, - ValueRange &initArgs, + parallelAxisGenerateForLoop(OpBuilder &opBuilder, const int groupIdx, + const size_t parallelIdx, ValueRange &initArgs, + llvm::SmallVector &inductionVars, + Value &originalWriteResult); + + scf::ForOp + reductionAxisGenerateForLoop(OpBuilder &opBuilder, const int groupIdx, + const size_t reductionIdx, ValueRange &initArgs, llvm::SmallVector &inductionVars); + bool isGroupHasSpecialOperation(const size_t grpIdx); + void run(); }; } // namespace diff --git a/lib/gc/Transforms/CPUPhysicalResigterPass.cpp b/lib/gc/Transforms/CPUPhysicalResigterPass.cpp index ba4489476..008051245 100644 --- a/lib/gc/Transforms/CPUPhysicalResigterPass.cpp +++ b/lib/gc/Transforms/CPUPhysicalResigterPass.cpp @@ -7,11 +7,6 @@ // //===----------------------------------------------------------------------===// #include "gc/Transforms/TilingVector.h" -#include "mlir/Dialect/Vector/IR/VectorOps.h" -#include 
"mlir/IR/Builders.h" -#include "mlir/IR/ValueRange.h" -#include "mlir/Support/LLVM.h" -#include "llvm/ADT/SmallVector.h" namespace mlir { namespace gc { @@ -22,6 +17,8 @@ namespace { #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") +// TODO: remove it in the future +bool disableSpecialOp = true; struct HardwareInfo { bool favx512f = true; @@ -578,7 +575,7 @@ void getOpSourceOps(Operation *op, llvm::DenseSet &srcOps) { auto accOperand = srcOperandsQueue.front(); srcOperandsQueue.pop_front(); auto accOperandOp = accOperand.getDefiningOp(); - if (visited.count(accOperandOp)) { + if (!accOperandOp or visited.count(accOperandOp)) { continue; } visited.insert(accOperandOp); @@ -637,7 +634,7 @@ void classifySourceRelatedOps(std::queue &accRelatedOps, while (!prevOps.empty()) { auto op = prevOps.front(); prevOps.pop(); - if (isSrcRelated(srcOps, op)) { + if (isSrcRelated(srcOps, op) or op == srcOp) { sourceRelatedOps.push(op); } else { accRelatedOps.push(op); @@ -708,27 +705,28 @@ makeNewTransferWriteOp(Value source, IRMapping &writeMap, OpBuilder &b, return newWriteOp; } -Value makeIndexArithConstantOp(OpBuilder &opBuilder, Location &loc, int64_t x) { +Value makeIndexArithConstantOp(OpBuilder &opBuilder, const Location &loc, + int64_t x) { return opBuilder.create( loc, opBuilder.getIndexType(), opBuilder.getIntegerAttr(opBuilder.getIndexType(), x)); } void moveOperationsToCurrentForBody( - std::queue opQueue, OpBuilder &b, ValueRange loopState, + std::queue &opQueue, OpBuilder &b, ValueRange &loopState, const llvm::DenseMap &operandIdxMap, const llvm::SmallVector &inductionVars, - llvm::DenseMap &opPermuationMap) { - Operation *lastOperation = opQueue.front(); + const llvm::DenseMap &opPermuationMap) { + // Operation *lastOperation = opQueue.front(); while (!opQueue.empty()) { auto x = opQueue.front(); opQueue.pop(); - if (lastOperation == x) { - x->moveBefore(b.getBlock(), b.getBlock()->begin()); - } else { - x->moveAfter(lastOperation); - lastOperation = x; - } + // if (lastOperation == x) { + x->moveBefore(b.getBlock(), b.getBlock()->end()); + // } else { + // x->moveAfter(lastOperation); + // lastOperation = x; + // } // check operation type to set correct operand checkAndSetOperand(x, loopState, operandIdxMap, inductionVars, opPermuationMap); @@ -736,16 +734,15 @@ void moveOperationsToCurrentForBody( } scf::ForOp CanonicalizerVectorOperation::reductionAxisGenerateForLoop( - const int groupIdx, const size_t reductionIdx, ValueRange &initArgs, - llvm::SmallVector &inductionVars) { + OpBuilder &opBuilder, const int groupIdx, const size_t reductionIdx, + ValueRange &initArgs, llvm::SmallVector &inductionVars) { MultiReductionCanonicalizer rdCanonicalizer = commonUsedData.getMultiRdCanonicalizer()[groupIdx]; auto &multireductionOp = rdCanonicalizer.getCandidateOps()[0]; - OpBuilder opBuilder(multireductionOp); - auto loc = multireductionOp->getLoc(); + const auto loc = multireductionOp->getLoc(); auto &reductionAxis = rdCanonicalizer.getReductionAxis(); auto lastDimReduction = rdCanonicalizer.hasLastDimReduction(); - auto vectorType = multireductionOp.getSourceVectorType(); + auto vectorType = rdCanonicalizer.getSourceType(); const int loopStep = getDataTypeMAXSIMDLength(vectorType); auto isStandaloneOp = rdCanonicalizer.getIsStandAloneOp(); @@ -754,40 +751,79 @@ scf::ForOp CanonicalizerVectorOperation::reductionAxisGenerateForLoop( opBuilder, loc, (reductionIdx == reductionAxis.size() - 1 && lastDimReduction) ? 
loopStep : 1); - auto numIter = makeIndexArithConstantOp(opBuilder, loc, - vectorType.getShape()[reductionIdx]); + auto numIter = makeIndexArithConstantOp( + opBuilder, loc, vectorType.getShape()[reductionAxis[reductionIdx]]); auto forOp = opBuilder.create( loc, zero, numIter, forSteps, initArgs, [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) { inductionVars.emplace_back(iv); if (reductionIdx == reductionAxis.size() - 1) { - IRRewriter rewriter(b); - IRMapping readMap; - Value reductionTarget = multireductionOp.getSource(); - llvm::SmallVector parallelAxis; - auto newReadOp = cloneReductionTransferRead( - reductionTarget, b, readMap, parallelAxis, inductionVars, - lastDimReduction, MultiReduceOpAxisKind::Reduction); if (isStandaloneOp) { + IRRewriter rewriter(b); + IRMapping readMap; + Value reductionTarget = multireductionOp.getSource(); + llvm::SmallVector parallelAxis; + auto newReadOp = cloneReductionTransferRead( + reductionTarget, b, readMap, parallelAxis, inductionVars, + lastDimReduction, MultiReduceOpAxisKind::Reduction); + auto reductionResult = + makeArithReduction(b, loc, multireductionOp.getKind(), + newReadOp->getResult(0), loopState.back()); + maybeYieldValue(b, loc, reductionResult); + } else { + auto &opPermuationMap = commonUsedData.getOpPermuationMap(); + auto &analysisResults = + commonUsedData.getGroupOpResults()[groupIdx]; + + auto &sourceOps = commonUsedData.getMultiRdCanonicalizer()[groupIdx] + .getSourceRelatedOps(); + auto &grpArgs = commonUsedData.getGroupOpIterArgs()[groupIdx]; + + rewriteOperationAsVectorize(b, groupIdx, sourceOps); + llvm::DenseMap operandIdxMap; + llvm::SmallVector resultArray; + // dummy + resultArray.emplace_back(Value()); + std::queue tmpSourceOps(sourceOps); + // move operation into current for loop body + // accVal is first loopstate + int start = 1; + while (!tmpSourceOps.empty()) { + auto cur = tmpSourceOps.front(); + tmpSourceOps.pop(); + auto curOperands = cur->getOperands(); + for (auto x : curOperands) { + if (grpArgs.contains(x)) { + operandIdxMap[x] = start++; + } + } + if (analysisResults.contains(cur->getResults()[0])) { + resultArray.emplace_back(cur->getResults()[0]); + commonUsedData.getMultiRdCanonicalizer()[groupIdx] + .getOriginalOpResults() + .insert(cur->getResults()[0]); + commonUsedData.getMultiRdCanonicalizer()[groupIdx] + .getResultIdxMap() + .insert({cur->getResults()[0], resultArray.size() - 1}); + } + } + moveOperationsToCurrentForBody(sourceOps, b, loopState, + operandIdxMap, inductionVars, + opPermuationMap); + + auto reductionResult = makeArithReduction( + b, loc, multireductionOp.getKind(), + multireductionOp.getSource(), loopState.back()); + resultArray[0] = reductionResult; + + maybeYieldValue(b, loc, resultArray); } - // reduction or elementwise reduce - // if (lastDimReduction) { - // Operation *reductionOp = rewriter.create( - // loc, multiReductionOp.getKind(), newReadOp->getResult(0), - // loopState.back()); - // maybeYieldValue(b, loc, reductionOp->getResults()); - // } else { - auto reductionResult = - makeArithReduction(b, loc, multireductionOp.getKind(), - newReadOp->getResult(0), loopState.back()); - maybeYieldValue(b, loc, reductionResult); - // } } else { // outter loop - auto nxtFor = reductionAxisGenerateForLoop(groupIdx, reductionIdx + 1, - loopState, inductionVars); + auto nxtFor = reductionAxisGenerateForLoop( + b, groupIdx, reductionIdx + 1, loopState, inductionVars); maybeYieldValue(b, loc, nxtFor->getResults()); } }); @@ -796,24 +832,24 @@ scf::ForOp 
CanonicalizerVectorOperation::reductionAxisGenerateForLoop( } scf::ForOp CanonicalizerVectorOperation::parallelAxisGenerateForLoop( - const int groupIdx, const int parallelIdx, ValueRange &initArgs, - llvm::SmallVector &inductionVars, Value &originalWriteResult) { - MultiReductionCanonicalizer rdCanonicalizer = - commonUsedData.getMultiRdCanonicalizer()[groupIdx]; + OpBuilder &opBuilder, const int groupIdx, const size_t parallelIdx, + ValueRange &initArgs, llvm::SmallVector &inductionVars, + Value &originalWriteResult) { + auto &rdCanonicalizer = commonUsedData.getMultiRdCanonicalizer()[groupIdx]; auto &multiReductionOp = rdCanonicalizer.getCandidateOps()[0]; - auto vectorType = multiReductionOp.getSourceVectorType(); + auto &vectorType = rdCanonicalizer.getSourceType(); + auto &accType = rdCanonicalizer.getAccType(); + auto ¶llelAxis = rdCanonicalizer.getParallelAxis(); auto isStandaloneOp = rdCanonicalizer.getIsStandAloneOp(); auto lastDimReduction = rdCanonicalizer.hasLastDimReduction(); - OpBuilder opBuilder(multiReductionOp); - auto loc = multiReductionOp.getLoc(); + const auto &loc = multiReductionOp.getLoc(); auto zero = makeIndexArithConstantOp(opBuilder, loc, 0); auto forSteps = makeIndexArithConstantOp(opBuilder, loc, 1); // last dim reduction need to a generate dim=16 loop int dimSize = 0; - if (lastDimReduction && parallelIdx == parallelAxis.size() && - !isStandaloneOp) { + if (parallelIdx == parallelAxis.size()) { dimSize = 16; } else { dimSize = vectorType.getShape()[parallelAxis[parallelIdx]]; @@ -824,8 +860,8 @@ scf::ForOp CanonicalizerVectorOperation::parallelAxisGenerateForLoop( loc, zero, numIter, forSteps, initArgs, [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) { inductionVars.emplace_back(iv); - auto fusionStrategy = commonUsedData.getFusionStrategy(); - auto opIndexMap = fusionStrategy.getOpGroupIndexMap(); + auto &fusionStrategy = commonUsedData.getFusionStrategy(); + auto &opIndexMap = fusionStrategy.getOpGroupIndexMap(); assert(opIndexMap.contains(multiReductionOp) && " Must constains multireduction operation."); @@ -835,7 +871,7 @@ scf::ForOp CanonicalizerVectorOperation::parallelAxisGenerateForLoop( auto &opPermuationMap = commonUsedData.getOpPermuationMap(); auto opQueue = opGroups[opIndex]; auto multiReductionAcc = multiReductionOp.getAcc(); - auto accType = mlir::dyn_cast(multiReductionAcc.getType()); + if (parallelIdx == parallelAxis.size() - 1) { // four kinds of group operations // If fused a operation, it means multirection must just @@ -916,8 +952,8 @@ scf::ForOp CanonicalizerVectorOperation::parallelAxisGenerateForLoop( loc, DenseElementsAttr::get(accType, {initValueAttr})); ValueRange newIterArgs(accVal); - auto nxtFor = reductionAxisGenerateForLoop(groupIdx, 0, newIterArgs, - inductionVars); + auto nxtFor = reductionAxisGenerateForLoop( + b, groupIdx, 0, newIterArgs, inductionVars); // insert accumulate value to original vector auto accRes = nxtFor->getResults()[0]; @@ -949,61 +985,124 @@ scf::ForOp CanonicalizerVectorOperation::parallelAxisGenerateForLoop( originalWriteResult = newAccWriteOp->getResult(0); maybeYieldValue(b, loc, newAccWriteOp->getResults()); + } else { + auto prevOp = opQueue.front(); + auto postOp = opQueue.back(); + auto &prevOps = + commonUsedData.getMultiRdCanonicalizer()[groupIdx].getPrevOps(); + auto &postOps = + commonUsedData.getMultiRdCanonicalizer()[groupIdx].getPostOps(); + auto &accRelatedOps = + commonUsedData.getMultiRdCanonicalizer()[groupIdx] + .getAccRelatedOps(); + auto &sourceRelatedOps = + 
commonUsedData.getMultiRdCanonicalizer()[groupIdx] + .getSourceRelatedOps(); + + if (mlir::isa(prevOp)) { + + } else { + if (mlir::isa(postOp)) { + // prevOp + reduction op + } else { + // prevOp + reduction op + postOp + // reduction op + postOp + getPrevOps(prevOps, opQueue, multiReductionOp); + getPostOps(postOps, opQueue, multiReductionOp); + // analysis acc related operation + classifySourceRelatedOps( + accRelatedOps, sourceRelatedOps, + multiReductionOp.getSource().getDefiningOp(), prevOps); + + rewriteOperationAsVectorize(b, groupIdx, accRelatedOps); + auto &grpArgs = commonUsedData.getGroupOpIterArgs()[groupIdx]; + llvm::DenseMap operandIdxMap; + for (auto [idx, x] : llvm::enumerate(grpArgs)) { + operandIdxMap[x] = idx; + } + moveOperationsToCurrentForBody(accRelatedOps, b, loopState, + operandIdxMap, inductionVars, + opPermuationMap); + auto &grpResults = commonUsedData.getGroupOpResults()[groupIdx]; + // next for loop + llvm::SmallVector iterArgsArray; + iterArgsArray.emplace_back(multiReductionAcc); + std::queue tmpSourceOps(sourceRelatedOps); + while (!tmpSourceOps.empty()) { + auto cur = tmpSourceOps.front(); + tmpSourceOps.pop(); + auto curResults = cur->getResults(); + for (auto x : curResults) { + if (grpResults.contains(x)) { + for (auto y : cur->getOperands()) { + if (grpArgs.contains(y)) { + iterArgsArray.emplace_back(y); + } + } + } + } + } + ValueRange reductionAxisArgs(iterArgsArray); + auto nxtFor = parallelAxisGenerateForLoop( + b, groupIdx, parallelIdx + 1, reductionAxisArgs, + inductionVars, originalWriteResult); + + rewriteOperationAsVectorize(b, groupIdx, postOps); + moveOperationsToCurrentForBody(postOps, b, loopState, + operandIdxMap, inductionVars, + opPermuationMap); + + auto replaceIfFn = [&](OpOperand &use) { + return use.getOwner()->getBlock() == nxtFor->getBlock(); + }; + rewriter.replaceOpUsesWithIf( + multiReductionOp, nxtFor->getResults()[0], replaceIfFn); + auto &originalResults = + commonUsedData.getMultiRdCanonicalizer()[groupIdx] + .getOriginalOpResults(); + for (auto [idx, x] : llvm::enumerate(originalResults)) { + rewriter.replaceOpUsesWithIf(x.getDefiningOp(), + nxtFor->getResults()[idx + 1], + replaceIfFn); + } + llvm::SmallVector resultsArray; + llvm::SmallDenseMap parallelIdxMap; + for (auto &x : grpResults) { + if (originalResults.contains(x)) { + auto &idxMap = rdCanonicalizer.getResultIdxMap(); + resultsArray.emplace_back(nxtFor->getResults()[idxMap[x]]); + } else { + resultsArray.emplace_back(x); + } + parallelIdxMap.insert({x, resultsArray.size() - 1}); + } + rdCanonicalizer.setResultIdxMap(parallelIdxMap); + maybeYieldValue(b, loc, resultsArray); + + // prepare iterArgs + } + } } - // else { - // auto prevOp = opQueue.front(); - // auto postOp = opQueue.back(); - - // if (mlir::isa(prevOp)) { - - // } else { - // if (mlir::isa(postOp)) { - // // prevOp + reduction op - // } else { - // // prevOp + reduction op + postOp - // // reduction op + postOp - - // getPrevOps(prevOps, opQueue, multiReductionOp); - // getPostOps(postOps, opQueue, multiReductionOp); - // // analysis acc related operation - // std::queue accRelatedOps, sourceRelatedOps; - // llvm::SmallVector iterArgsArray; - - // // prevOp need to classify - // classifySourceRelatedOps( - // accRelatedOps, sourceRelatedOps, - // multiReductionOp.getSource().getDefiningOp(), prevOps); - // rewriteOperationAsVectorize(prevOps, opIndexMap, b, - // opPermuationMap); - // moveOperationsToCurrentForBody(accRelatedOps, b, loopState, - // operandIdxMap, inductionVars, - // 
opPermuationMap); - // iterArgsArray.emplace_back(multiReductionAcc); - // ValueRange reductionAxisArgs(iterArgsArray); - // auto nxtFor = parallelAxisGenerateForLoop( - // b, multiReductionOp, parallelAxis, parallelIdx + 1, - // reductionAxis, reductionIdx, vectorType, inductionVars, - // loopState, operandIdxMap, originalWriteResult, - // lastDimReduction, loc, loopStep, canonicalizer, - // isStandaloneOp); - - // // prepare iterArgs - // } - // } - // } } else { - if (parallelIdx == parallelAxis.size() && !isStandaloneOp && - lastDimReduction) { + if (parallelIdx == parallelAxis.size() && !isStandaloneOp) { Attribute initValueAttr; getReductionInitAttr(multiReductionOp, initValueAttr); auto accVal = b.create( - loc, DenseElementsAttr::get(accType, {initValueAttr})); - ValueRange newIterArgs(accVal); - auto nxtFor = reductionAxisGenerateForLoop(groupIdx, 0, newIterArgs, - inductionVars); + loc, DenseElementsAttr::get(getVectorzedType(multiReductionOp), + {initValueAttr})); + llvm::SmallVector argsArray; + argsArray.emplace_back(accVal); + for (auto [idx, x] : llvm::enumerate(loopState)) { + if (idx == 0) + continue; + argsArray.emplace_back(x); + } + ValueRange newIterArgs(argsArray); + auto nxtFor = reductionAxisGenerateForLoop( + b, groupIdx, 0, newIterArgs, inductionVars); // insert accumulate value to original vector auto accRes = nxtFor->getResults()[0]; @@ -1011,12 +1110,24 @@ scf::ForOp CanonicalizerVectorOperation::parallelAxisGenerateForLoop( loc, multiReductionOp.getKind(), accRes); auto insertOp = b.create( loc, reductionOp->getResult(0), initArgs[0], iv); - maybeYieldValue(b, loc, insertOp->getResults()); + auto insertResult = insertOp->getResults()[0]; + + // result + llvm::SmallVector retResults; + retResults.emplace_back(insertResult); + for (auto [idx, x] : llvm::enumerate(nxtFor->getResults())) { + if (idx == 0) { + continue; + } + retResults.emplace_back(x); + } + ValueRange retResultsArray(retResults); + maybeYieldValue(b, loc, retResultsArray); } else { - auto nxtFor = parallelAxisGenerateForLoop(groupIdx, parallelIdx + 1, - loopState, inductionVars, - originalWriteResult); + auto nxtFor = parallelAxisGenerateForLoop( + b, groupIdx, parallelIdx + 1, loopState, inductionVars, + originalWriteResult); maybeYieldValue(b, loc, nxtFor->getResults()); } } @@ -1026,20 +1137,27 @@ scf::ForOp CanonicalizerVectorOperation::parallelAxisGenerateForLoop( scf::ForOp CanonicalizerVectorOperation::generateMultiReductionForLoop( const size_t grpIdx) { - auto &grpResults = commonUsedData.getGroupOpResults()[grpIdx]; - llvm::SmallVector forLoopArgs(grpResults.begin(), grpResults.end()); + auto &grpArgs = commonUsedData.getGroupOpIterArgs()[grpIdx]; + llvm::SmallVector forLoopArgs(grpArgs.begin(), grpArgs.end()); llvm::SmallVector inductionVars; ValueRange initArgs(forLoopArgs); Value originalWriteResult; + auto &rdCanonicalizer = commonUsedData.getMultiRdCanonicalizer()[grpIdx]; + auto &rdResultMap = rdCanonicalizer.getResultIdxMap(); + + OpBuilder opBuilder(rdCanonicalizer.getCandidateOps()[0]); + + scf::ForOp forOp = parallelAxisGenerateForLoop( + opBuilder, grpIdx, 0, initArgs, inductionVars, originalWriteResult); - scf::ForOp forOp = parallelAxisGenerateForLoop(0, 0, initArgs, inductionVars, - originalWriteResult); auto replaceIfFn = [&](OpOperand &use) { - return use.getOwner()->getBlock() != - originalWriteResult.getDefiningOp()->getBlock(); + return use.getOwner()->getBlock() == forOp->getBlock(); }; - rewriter.replaceOpUsesWithIf(originalWriteResult.getDefiningOp(), - 
forOp->getResults()[0], replaceIfFn); + for (auto &grpResult : commonUsedData.getGroupOpResults()[grpIdx]) { + rewriter.replaceOpUsesWithIf(grpResult.getDefiningOp(), + forOp->getResults()[rdResultMap[grpResult]], + replaceIfFn); + } rewriter.replaceOp( commonUsedData.getMultiRdCanonicalizer()[grpIdx].getCandidateOps()[0], @@ -1047,8 +1165,8 @@ scf::ForOp CanonicalizerVectorOperation::generateMultiReductionForLoop( return forOp; } -llvm::SmallVector & -MultiReductionCanonicalizer::getCandidateOps() { +template +llvm::SmallVector &SpecialOperationCanonicalizer::getCandidateOps() { return candidateRdOps; }; @@ -1087,7 +1205,7 @@ void CanonicalizerVectorOperation::getCandidateSpecialOps() { void MultiReductionCanonicalizer::initReductionAxis() { auto reductionAxisRange = - candidateRdOps[0].getReductionDims().getAsValueRange(); + getCandidateOps()[0].getReductionDims().getAsValueRange(); auto reductionRange = llvm::to_vector<4>(llvm::map_range( reductionAxisRange, [](const APInt &a) { return a.getZExtValue(); })); reductionAxis.assign(reductionRange.begin(), reductionRange.end()); @@ -1105,8 +1223,7 @@ void MultiReductionCanonicalizer::initParallelAxis() { } int64_t MultiReductionCanonicalizer::getTypeRank() { - auto srcVecType = candidateRdOps[0].getSourceVectorType(); - auto srcRank = srcVecType.getRank(); + auto srcRank = sourceType.getRank(); typeRank = srcRank; return srcRank; } @@ -1127,68 +1244,76 @@ bool MultiReductionCanonicalizer::hasLastDimReduction() { return res; } -void MultiReductionCanonicalizer::prepareReductionInfo() { +void MultiReductionCanonicalizer::prepareSpecialOperationInfo() { + if (getCandidateOps().empty()) { + return; + } + sourceType = getCandidateOps()[0].getSourceVectorType(); + accType = mlir::dyn_cast(getCandidateOps()[0].getAcc().getType()); getTypeRank(); getReductionAxisAndParallelAxis(); hasLastDimReduction(); }; +template void addDummyInit(llvm::SmallVector &canonicalizer) { + canonicalizer.emplace_back(T({})); +}; + +void CanonicalizerCommonUsedData::initSpeicalOperationCanonicalizers() { + + broadcastCanonicalizers.clear(); + multiRdCanonicalizers.clear(); + transposeCanonicalizers.clear(); + auto &opGroups = fusionStrategy.getOpGroups(); + for (auto &grp : opGroups) { + addDummyInit(multiRdCanonicalizers); + addDummyInit(broadcastCanonicalizers); + addDummyInit(transposeCanonicalizers); + + if (grp.empty()) { + continue; + } + std::queue tempQ(grp); + while (!tempQ.empty()) { + auto op = tempQ.front(); + tempQ.pop(); + if (mlir::isa(op)) { + multiRdCanonicalizers.back().getCandidateOps().emplace_back( + mlir::dyn_cast(op)); + } else if (mlir::isa(op)) { + broadcastCanonicalizers.back().getCandidateOps().emplace_back( + mlir::dyn_cast(op)); + } else if (mlir::isa(op)) { + transposeCanonicalizers.back().getCandidateOps().emplace_back( + mlir::dyn_cast(op)); + } + } + // todo + multiRdCanonicalizers.back().prepareSpecialOperationInfo(); + } +} + LogicalResult CanonicalizerVectorOperation::canonicalizeReductionOperation() { OpBuilder::InsertionGuard guard(rewriter); + commonUsedData.initSpeicalOperationCanonicalizers(); // traverse all groups - auto &multiRdCanonicalizer = commonUsedData.getMultiRdCanonicalizer(); + auto &multiRdCanonicalizers = commonUsedData.getMultiRdCanonicalizer(); for (auto [groupId, rdCanonicalizer] : - llvm::enumerate(multiRdCanonicalizer)) { + llvm::enumerate(multiRdCanonicalizers)) { auto &candidateOps = rdCanonicalizer.getCandidateOps(); if (candidateOps.empty()) { continue; } // generate MultiReduction for loops - auto forOp = 
generateMultiReductionForLoop(groupId); - // update uses - } - // Separate reduction and parallel dims - // Operation *newReduction; - // auto accSourceOp = multiReductionAcc.getDefiningOp(); - // llvm::SmallVector initIterArgs; - // // process Acc operand - // if (mlir::dyn_cast(accSourceOp)) { - // auto accTensorReadOp = - // multiReductionAcc.getDefiningOp(); - // initIterArgs.emplace_back(accTensorReadOp->getOperand(0)); - // } - // auto dstOperandSet = commonUsedData.getGroupOpIterArgs()[grpIdx]; - // llvm::SmallVector operands; - // llvm::DenseMap operandIdxMap; - // for (auto [idx, x] : llvm::enumerate(dstOperandSet)) { - // initIterArgs.emplace_back(x); - // operandIdxMap[x] = operands.size() - 1; - // } - - // Value originalWriteResult; - // ValueRange iterArgs(initIterArgs); - // llvm::SmallVector inductionVars; - // auto forOp = generateMultiReductionForLoop( - // rewriter, multiReductionOp, parallelAxis, 0, reductionAxis, 0, - // srcVecType, inductionVars, iterArgs, operandIdxMap, - // originalWriteResult, *this, lastDimReduction, isStandaloneOp); - // auto replaceIfFn = [&](OpOperand &use) { - // return use.getOwner()->getBlock() != - // originalWriteResult.getDefiningOp()->getBlock(); - // }; - // newReduction = forOp; - // rewriter.replaceOpUsesWithIf(originalWriteResult.getDefiningOp(), - // newReduction->getResults()[0], replaceIfFn); - - // rewriter.replaceOp(firstOp, newReduction); + // (void)generateMultiReductionForLoop(groupId); + } return success(); } void CanonicalizerVectorOperation::canonicalizeSpecialOperation() { // multireduction operation auto result = canonicalizeReductionOperation(); - // canonicalizeBroadCastOperation(); } void CanonicalizerVectorOperation::run() { @@ -1232,10 +1357,8 @@ void CanonicalizerVectorOperation::run() { // Query groupResultYeildSet to map operaion result value to scf.yield // result value. analysisGroupOperaionOperandsResults(); - std::cout << "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!" - << " : " << fusionStrategy.getOpGroups().size() << std::endl; // Speical Operation Canonicalization - // canonicalizeSpecialOperation(); + canonicalizeSpecialOperation(); // 2.Generate vectorized IR for each operation group for (auto [idx, grp] : llvm::enumerate(fusionStrategy.getOpGroups())) { @@ -1291,13 +1414,15 @@ void checkAndSetOperand( if (llvm::dyn_cast(op) || llvm::dyn_cast(op)) { assert(opPermuationMap.contains(op)); - + std::cout << "op verify ..." << std::endl; + op->dump(); auto permutationMap = opPermuationMap.at(op); auto dimExpr = permutationMap.getResults(); for (auto [idx, x] : llvm::enumerate(dimExpr)) { if (mlir::dyn_cast(x)) { auto dim = mlir::dyn_cast(x).getPosition(); + std::cout << inductionVars.size() << "," << dim << std::endl; op->setOperand(dim + offset, inductionVars[dim]); } } @@ -1312,18 +1437,11 @@ scf::ForOp constructNestedForOp( const llvm::DenseMap &operandIdxMap, const llvm::DenseMap &opPermuationMap) { const int loop_step = getDataTypeMAXSIMDLength(type); - // loop initialization variable - auto zero = - b.create(b.getUnknownLoc(), b.getIndexType(), - b.getIntegerAttr(b.getIndexType(), 0)); - auto forSteps = b.create( - b.getUnknownLoc(), b.getIndexType(), - b.getIntegerAttr(b.getIndexType(), - idx == dims.size() - 1 ? loop_step : 1)); - auto numIter = b.create( - b.getUnknownLoc(), b.getIndexType(), - b.getIntegerAttr(b.getIndexType(), dims[idx])); + auto zero = makeIndexArithConstantOp(b, loc, 0); + auto forSteps = + makeIndexArithConstantOp(b, loc, idx == dims.size() - 1 ? 
loop_step : 1); + auto numIter = makeIndexArithConstantOp(b, loc, dims[idx]); // Create a loop and move vectorized operation into loops. auto forOp = b.create( @@ -1333,23 +1451,10 @@ scf::ForOp constructNestedForOp( // inner most body of the loop if (idx == dims.size() - 1) { - Operation *lastOperation = queue.front(); - while (!queue.empty()) { - auto x = queue.front(); - queue.pop(); - if (lastOperation == x) { - x->moveBefore(b.getBlock(), b.getBlock()->begin()); - } else { - x->moveAfter(lastOperation); - lastOperation = x; - } - // check operation type to set correct operand - checkAndSetOperand(x, loopState, operandIdxMap, inductionVars, - opPermuationMap); - } + moveOperationsToCurrentForBody(queue, b, loopState, operandIdxMap, + inductionVars, opPermuationMap); maybeYieldValue(b, loc, resultSet.getArrayRef()); } else { - // outter loop auto nxtFor = constructNestedForOp( b, loc, loopState, type, dims, idx + 1, queue, resultSet, @@ -1366,8 +1471,23 @@ bool isCompatibleVectorType(Operation *op1, Operation *op2) { if (failed(type1) || failed(type2)) { return false; } + auto isReadOrWrite = [](Operation *op) { + return mlir::isa(op) or + mlir::isa(op); + }; auto sp1 = type1.value(); auto sp2 = type2.value(); + if (isReadOrWrite(op1) or isReadOrWrite(op2)) { + if (sp1.getRank() != sp2.getRank()) { + return false; + } + for (long i = 0; i < sp1.getRank(); i++) { + if (sp1.getDimSize(i) != sp2.getDimSize(i)) { + return false; + } + } + } + auto min_rank = std::min(sp1.getRank(), sp2.getRank()) - 1; bool isCompatible = true; // from front to back @@ -1427,6 +1547,37 @@ void shapeCastSourceAxis(const ArrayRef &a, const ArrayRef &b, assert(i == rankA && j == rankB && "Invalid shapecast operation."); } +bool isScalar(Type type) { + assert(type && "Not a valid type"); + if (auto vecType = dyn_cast(type)) + return false; + if (auto tensorType = dyn_cast(type)) + return false; + return true; +} + +void getSrcBroadcastDim(const ShapedType &input, const ShapedType &output, + llvm::SmallVector &bcAxis) { + auto inputShape = input.getShape(); + auto outputShape = output.getShape(); + // following auto_broadcast semantics + const size_t input_rank = inputShape.size(); + const size_t output_rank = outputShape.size(); + assert(output_rank >= input_rank && + "Incorrect input or output shape for broadcast op."); + const size_t offset = output_rank - input_rank; + for (size_t i = 0; i < input_rank; ++i) { + if (inputShape[i] == outputShape[i + offset] || + (ShapedType::isDynamic(inputShape[i]) && + ShapedType::isDynamic(outputShape[i + offset]))) { + bcAxis.emplace_back(i); + } + } + if (bcAxis.empty()) { + bcAxis.emplace_back(-1); + } +} + void getOperationDataAxis(Operation *op, llvm::SmallVector &dataAxis) { return TypeSwitch(op) .Case( @@ -1448,6 +1599,27 @@ void getOperationDataAxis(Operation *op, llvm::SmallVector &dataAxis) { shapeCastSourceAxis(dstShape, srcShape, dataAxis); } }) + .Case([&](vector::BroadcastOp broadcastOp) { + auto srcType = broadcastOp.getSourceType(); + auto dstType = broadcastOp.getResultVectorType(); + if (isScalar(srcType)) { + dataAxis.emplace_back(0); + } else { + auto inputType = mlir::cast(srcType); + auto outputType = mlir::cast(dstType); + getSrcBroadcastDim(inputType, outputType, dataAxis); + } + }) + .Case([&](vector::TransposeOp transposeOp) { + auto perm = transposeOp.getPermutation(); + int start = 0; + for (auto x : perm) { + if (x != start) { + dataAxis.emplace_back(x); + } + start++; + } + }) .Default([&](Operation *op) { // default is last axis 
dataAxis.emplace_back( @@ -1464,7 +1636,8 @@ bool hasDataDependency(Operation *op1, Operation *op2) { if (!isSpecialOp(op1)) { return hasDataDependency(op2, op1); } - if (isSpecialOp(op1)) { + // TODO: remove this condition in the future + if (disableSpecialOp) { return true; } auto hasSameAxis = [](const llvm::SmallVector &dims1, @@ -1485,30 +1658,80 @@ bool hasDataDependency(Operation *op1, Operation *op2) { getOperationDataAxis(op2, dims2); return hasSameAxis(dims1, dims2); }) - .Case<>([&](vector::MultiDimReductionOp multiReductionOp) { - // op1 is special operation, op2 is normal operation - // op1 and op2 is both speicial operation - auto rdDimsRange = multiReductionOp.getReductionDims() - .getAsValueRange(); - auto reductionDims = llvm::to_vector( - llvm::map_range(rdDimsRange, [](const APInt &a) { - return (int64_t)a.getZExtValue(); - })); - llvm::SmallVector dims2; - getOperationDataAxis(op2, dims2); - llvm::DenseSet checkSet(dims2.begin(), dims2.end()); + .Case( + [&](vector::MultiDimReductionOp multiReductionOp) { + // op1 is special operation, op2 is normal operation + // op1 and op2 is both speicial operation + llvm::SmallVector dims2, reductionDims, parallelDims; + getOperationDataAxis(op1, reductionDims); + getOperationDataAxis(op2, dims2); + llvm::DenseSet checkSet(dims2.begin(), dims2.end()); + + if (!isSpecialOp(op2)) { + // all reduction axis should be op2's data axis + bool reduceDependent = false; + for (auto x : reductionDims) { + if (!checkSet.contains(x)) { + reduceDependent = true; + break; + } + } + if (!reduceDependent) { + return false; + } + // all parallel axis should equal to op2's axis + checkSet.clear(); + checkSet.insert(reductionDims.begin(), reductionDims.end()); + auto rdRank = + multiReductionOp.getSourceVectorType().getRank(); + for (auto i = 0; i < rdRank; i++) { + if (!checkSet.contains(i)) { + parallelDims.emplace_back(i); + } + } + checkSet.clear(); + checkSet.insert(parallelDims.begin(), parallelDims.end()); + auto rank = + mlir::dyn_cast(op2->getResultTypes()[0]) + .getRank(); + for (auto i = 0; i < rank; i++) { + if (!checkSet.contains(i)) { + return true; + } + } - if (!isSpecialOp(op2)) { - for (auto x : reductionDims) { - if (!checkSet.contains(x)) { - return true; + return false; + } else { + // TODO: reduce operation fused with other special operation + if (mlir::isa(op2)) { + return true; + } else if (mlir::isa(op2)) { + return true; + } + //... 
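+              // Conservative fallback: fusing a multi-dimensional reduction
+              // with another special operation (the transpose / broadcast
+              // checks above) is not modelled yet, so every such pairing
+              // reports a dependency here.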
} - } + + return true; + }) + .Case([&](vector::BroadcastOp broadcastOp) { + llvm::SmallVector dims1, dims2; + getOperationDataAxis(op1, dims1); + getOperationDataAxis(op2, dims2); + if (!isSpecialOp(op2)) { + return hasSameAxis(dims1, dims2); } else { - // TODO: reduce operation fused with other special operation } - - return false; + return true; + }) + .Case([&](vector::TransposeOp transposeOp) { + llvm::SmallVector dims1, dims2; + getOperationDataAxis(op1, dims1); + getOperationDataAxis(op2, dims2); + if (!isSpecialOp(op2)) { + return hasSameAxis(dims1, dims2); + } else { + } + return true; }) .Default([&](Operation *op) { return false; }); @@ -1573,6 +1796,19 @@ void VectorFusionStrategy::classifyOperations() { } }); }); + for (auto grp : opGroups) { + std::cout << " ____________________" << std::endl; + if (grp.empty()) { + continue; + } + std::queue tmpQ(grp); + while (!tmpQ.empty()) { + auto cur = tmpQ.front(); + tmpQ.pop(); + cur->dump(); + } + std::cout << "___________________" << std::endl; + } } Value setOutGroupOperationOperandResult(Operation *op, @@ -1641,8 +1877,10 @@ void createNewConstantOp( /// Rewrite the operations in the group to vectorized form. void CanonicalizerVectorOperation::rewriteOperationAsVectorize( - OpBuilder &rewriter, size_t groupId) { - auto &groupOps = commonUsedData.getFusionStrategy().getOpGroups()[groupId]; + OpBuilder &rewriter, size_t groupId, const std::queue &queue) { + auto &groupOps = + queue.empty() ? commonUsedData.getFusionStrategy().getOpGroups()[groupId] + : queue; auto &opMap = commonUsedData.getFusionStrategy().getOpGroupIndexMap(); auto &opPermuationMap = commonUsedData.getOpPermuationMap(); std::queue transformQueue(groupOps); @@ -1660,12 +1898,11 @@ void CanonicalizerVectorOperation::rewriteOperationAsVectorize( if (mlir::isa(srcOp)) { createNewConstantOp(srcOp, &transferWriteOp, opPermuationMap); - } else if (!isSpecialOp(srcOp)) { - - transferWriteOp->getOperand(0).setType(newOperandType); - + } else { opPermuationMap.insert( {transferWriteOp, transferWriteOp.getPermutationMap()}); + transferWriteOp->getOperand(0).setType(newOperandType); + setOpVectorizationPermutationMap( transferWriteOp, rewriter, mlir::dyn_cast( @@ -1678,24 +1915,14 @@ void CanonicalizerVectorOperation::rewriteOperationAsVectorize( .Case( [&](vector::TransferReadOp transferReadOp) { auto newOperandType = getVectorzedType(transferReadOp); - auto users = transferReadOp->getUsers(); - bool isUserSpecial = false; - for (auto *opUse : users) { - if (isSpecialOp(opUse)) { - isUserSpecial = true; - break; - } - } - if (!isUserSpecial) { - opPermuationMap.insert( - {transferReadOp, transferReadOp.getPermutationMap()}); - transferReadOp->getResult(0).setType(newOperandType); - setOpVectorizationPermutationMap( - transferReadOp, rewriter, - mlir::dyn_cast( - transferReadOp.getSource().getType()), - transferReadOp.getPermutationMap()); - } + opPermuationMap.insert( + {transferReadOp, transferReadOp.getPermutationMap()}); + transferReadOp->getResult(0).setType(newOperandType); + setOpVectorizationPermutationMap( + transferReadOp, rewriter, + mlir::dyn_cast( + transferReadOp.getSource().getType()), + transferReadOp.getPermutationMap()); return success(); }) @@ -1712,7 +1939,9 @@ void CanonicalizerVectorOperation::rewriteOperationAsVectorize( }) .Default([&](Operation *op) { if (isSpecialOp(op)) { - return success(); + llvm::llvm_unreachable_internal( + "It should not appear this operation."); + return failure(); } setOperationOperandResult(op, getVectorzedType(op), opMap); 
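+          // For ordinary elementwise operations this is all that is needed:
+          // operand and result types are narrowed to the 1-D vectorized
+          // type, and the surrounding loop nest is generated later by
+          // constructNestedForOp.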
return success(); @@ -1771,6 +2000,25 @@ void updateOpOperandResultInGroups( void VectorFusionStrategy::run() { classifyOperations(); } +void CanonicalizerVectorOperation::generateEmptyTensorAndWrite( + Operation *sourceOp, llvm::DenseMap> + &srcOpCanoniclizedMap) { + auto &commonUsedData = getCommonUsedData(); + auto &opGroups = commonUsedData.getFusionStrategy().getOpGroups(); + auto &opGroupIndexMap = + commonUsedData.getFusionStrategy().getOpGroupIndexMap(); + auto &groupOpIterArgs = commonUsedData.getGroupOpIterArgs(); + auto &groupOpResults = commonUsedData.getGroupOpResults(); + auto sourceOpGid = opGroupIndexMap[sourceOp]; + + auto [resultTensor, result] = canonicalizeSourceOperation(sourceOp); + srcOpCanoniclizedMap.insert({sourceOp, {resultTensor, result}}); + updateOpOperandResultInGroups(opGroups, opGroupIndexMap, sourceOpGid, + sourceOp, resultTensor, result); + groupOpIterArgs[sourceOpGid].insert(resultTensor); + groupOpResults[sourceOpGid].insert(result); +} + // analysis operation result of current group whether needed by other // operation which out of current group void CanonicalizerVectorOperation::analysisGroupOperationResults() { @@ -1800,14 +2048,7 @@ void CanonicalizerVectorOperation::analysisGroupOperationResults() { if (failed(dstRet)) { // already generate result tensor if (!srcOpCanoniclizedMap.contains(sourceOp)) { - auto [resultTensor, result] = - canonicalizeSourceOperation(sourceOp); - srcOpCanoniclizedMap.insert({sourceOp, {resultTensor, result}}); - updateOpOperandResultInGroups(opGroups, opGroupIndexMap, - sourceOpGid, sourceOp, resultTensor, - result); - groupOpIterArgs[sourceOpGid].insert(resultTensor); - groupOpResults[sourceOpGid].insert(result); + generateEmptyTensorAndWrite(sourceOp, srcOpCanoniclizedMap); } auto opInit = canonicalizeCurrentOperation( @@ -1816,8 +2057,8 @@ void CanonicalizerVectorOperation::analysisGroupOperationResults() { opGroupIndexMap[op], op, opInit); } else { - // if source operation is transfer_read, we need to generate a same - // transfer_read operation like source operation. + // if source operation is transfer_read, we need to generate a + // same transfer_read operation like source operation. if (mlir::isa(sourceOp)) { auto transferReadOp = mlir::dyn_cast(sourceOp); @@ -1834,6 +2075,13 @@ void CanonicalizerVectorOperation::analysisGroupOperationResults() { } } } + // reduce have reduction axis and parallel axis, reduction loop must write + // result back + // if (mlir::isa(op)) { + // if (!srcOpCanoniclizedMap.contains(op)) { + // generateEmptyTensorAndWrite(op, srcOpCanoniclizedMap); + // } + // } }); // If the group operations do not have result need to be returned, these are // useless code. 
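For reference, the cross-group hand-off above works as follows: when a value produced in one group is consumed outside of it, the pass materializes a destination tensor plus a transfer_write in the producing group (generateEmptyTensorAndWrite) and re-reads it in the consuming group. The builder sketch below is only an illustration of that write side, assuming static shapes and zero start indices; it is not the exact canonicalizeSourceOperation helper.

// Illustration only: give `vec` a tensor destination so that a later group
// can transfer_read it back. The returned tensor is what ends up in
// groupOpIterArgs / groupOpResults.
static Value writeVectorToFreshTensor(OpBuilder &b, Location loc, Value vec,
                                      RankedTensorType destTy) {
  // Fresh destination tensor with the same shape/element type as the result.
  Value dest = b.create<tensor::EmptyOp>(loc, destTy.getShape(),
                                         destTy.getElementType());
  // Write the whole vector starting at index 0 in every dimension.
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  SmallVector<Value> indices(destTy.getRank(), zero);
  auto write = b.create<vector::TransferWriteOp>(
      loc, /*vector=*/vec, /*source=*/dest, indices,
      SmallVector<bool>(destTy.getRank(), true)); // all dims assumed in-bounds
  return write->getResult(0); // the updated tensor value
}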
@@ -1870,20 +2118,6 @@ mlir::FailureOr generateVectorizedForLoop( ValueRange iterArgs(operands); auto shapes = vectorType.getShape(); llvm::SmallVector inductionVars; - // TODO: special operation process - bool isOpSpecial = false; - std::queue tmpQ(queue); - // temporary for special operation generation - while (!tmpQ.empty()) { - if (isSpecialOp(tmpQ.front())) { - isOpSpecial = true; - break; - } - tmpQ.pop(); - } - if (isOpSpecial) { - return failure(); - } // generate for loop auto forOp = constructNestedForOp( rewriter, rewriter.getUnknownLoc(), iterArgs, vectorType, shapes, 0, @@ -1911,16 +2145,14 @@ void updateLoopResultUses(llvm::SetVector &opResults, } } -bool hasSpecialOperation(std::queue &grp) { - std::queue tmpQ(grp); - while (!tmpQ.empty()) { - auto curOp = tmpQ.front(); - if (isSpecialOp(curOp)) { - return true; - } - tmpQ.pop(); - } - return false; +bool CanonicalizerVectorOperation::isGroupHasSpecialOperation( + const size_t grpIdx) { + auto &rdCanonicalizer = commonUsedData.getMultiRdCanonicalizer()[grpIdx]; + auto &bcCanonicalizer = commonUsedData.getBroadcastCanonicalizer()[grpIdx]; + auto &tpCanonicalizer = commonUsedData.getTransposeCanonicalizer()[grpIdx]; + return !rdCanonicalizer.getCandidateOps().empty() or + !bcCanonicalizer.getCandidateOps().empty() or + !tpCanonicalizer.getCandidateOps().empty(); } void CanonicalizerVectorOperation::generateGroupOpVectorizedIR(const int idx) { @@ -1930,7 +2162,7 @@ void CanonicalizerVectorOperation::generateGroupOpVectorizedIR(const int idx) { return; } // TODO: special operation better fusion - if (hasSpecialOperation(grp)) { + if (isGroupHasSpecialOperation(idx)) { return; } auto &groupOpResults = commonUsedData.getGroupOpResults(); @@ -1947,14 +2179,6 @@ void CanonicalizerVectorOperation::generateGroupOpVectorizedIR(const int idx) { // 1. Rewrite operation as vectorized form rewriteOperationAsVectorize(rewriter, idx); // 2. 
Generate loop - // 2.a more init operation before current group operations - // auto firstGrpOp = grp.front(); - // while (!fusionStrategy.getIgnoreInitOperations()[idx].empty()) { - // auto initOp = fusionStrategy.getIgnoreInitOperations()[idx].front(); - // initOp->moveBefore(firstGrpOp); - // fusionStrategy.getIgnoreInitOperations()[idx].pop(); - // } - // 2.b generate common outter for loop auto forOp = generateVectorizedForLoop(rewriter, groupOpResults[idx], groupOpIterArgs[idx], opShapes, grp, opPermuationMap); diff --git a/lib/gc/Transforms/LowerTileVectorPass.cpp b/lib/gc/Transforms/LowerTileVectorPass.cpp index 42dd19bd0..c55674542 100644 --- a/lib/gc/Transforms/LowerTileVectorPass.cpp +++ b/lib/gc/Transforms/LowerTileVectorPass.cpp @@ -471,7 +471,7 @@ struct LowerTileVectorPass populateLowerToTileVectorPatterns(patterns); linalg::populatePadOpVectorizationPatterns(patterns); - // vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); + vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); vector::populateSinkVectorBroadcastPatterns(patterns); vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx); vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx); From b01472a679ad2006a21e638d6edf738ea18e52a1 Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Fri, 5 Jul 2024 16:23:38 +0800 Subject: [PATCH 11/66] fix tests --- include/gc/Transforms/TilingVector.h | 20 ++ lib/gc/Transforms/CPUPhysicalResigterPass.cpp | 254 ++++++++++++------ lib/gc/Transforms/LowerTileVectorPass.cpp | 218 ++++++++------- 3 files changed, 316 insertions(+), 176 deletions(-) diff --git a/include/gc/Transforms/TilingVector.h b/include/gc/Transforms/TilingVector.h index a280999d2..6f54d8a4b 100644 --- a/include/gc/Transforms/TilingVector.h +++ b/include/gc/Transforms/TilingVector.h @@ -78,6 +78,7 @@ void checkAndSetOperand( class VectorFusionStrategy { private: llvm::SmallVector, 8> opGroups; + llvm::SmallVector groupMaxSteps; // query current operation in which group, return group index llvm::DenseMap opGroupIndexMap; // can fused into prev operation which axis position @@ -94,6 +95,7 @@ class VectorFusionStrategy { llvm::DenseMap &getOpGroupIndexMap() { return opGroupIndexMap; } + llvm::SmallVector &getGroupMaxSteps() { return groupMaxSteps; } func::FuncOp getFunc() { return func; } llvm::SmallVector, 8> getIgnoreInitOperations() { @@ -186,6 +188,16 @@ class TransposeCanonicalizer void prepareSpecialOperationInfo() override {} }; +class ShapeCastCanonicalizer + : virtual public SpecialOperationCanonicalizer { +private: +public: + ShapeCastCanonicalizer( + const llvm::SmallVector &candidateScOps) + : SpecialOperationCanonicalizer(candidateScOps){}; + void prepareSpecialOperationInfo() override {} +}; + class CanonicalizerCommonUsedData { private: VectorFusionStrategy fusionStrategy; @@ -198,6 +210,7 @@ class CanonicalizerCommonUsedData { llvm::SmallVector multiRdCanonicalizers; llvm::SmallVector broadcastCanonicalizers; llvm::SmallVector transposeCanonicalizers; + llvm::SmallVector shapeCastCanonicalizers; public: CanonicalizerCommonUsedData() = default; @@ -265,6 +278,10 @@ class CanonicalizerCommonUsedData { return transposeCanonicalizers; } + llvm::SmallVector &getShapeCastCanonicalizer() { + return shapeCastCanonicalizers; + } + // other methods void initSpeicalOperationCanonicalizers(); }; @@ -296,6 +313,7 @@ class CanonicalizerVectorOperation { void generateGroupOpVectorizedIR(const int idx); + void analysisEmptyGroupAndMaxSteps(); void 
analysisGroupOperaionOperandsResults(); void generateEmptyTensorAndWrite( @@ -308,6 +326,8 @@ class CanonicalizerVectorOperation { IRRewriter &rewriter); void rewriteOperationAsVectorize(OpBuilder &rewriter, size_t groupId, const std::queue &queue = {}); + void createNewConstantOp(Operation *srcOp, + vector::TransferWriteOp *transferWriteOp); // special operation methods scf::ForOp generateMultiReductionForLoop(const size_t grpIdx); diff --git a/lib/gc/Transforms/CPUPhysicalResigterPass.cpp b/lib/gc/Transforms/CPUPhysicalResigterPass.cpp index 008051245..f98758c07 100644 --- a/lib/gc/Transforms/CPUPhysicalResigterPass.cpp +++ b/lib/gc/Transforms/CPUPhysicalResigterPass.cpp @@ -25,6 +25,20 @@ struct HardwareInfo { bool favx2 = true; } HW; +void printGroupOps(llvm::SmallVector, 8> &opGroups) { + for (auto grp : opGroups) { + if (grp.empty()) { + continue; + } + std::queue tmpQ(grp); + while (!tmpQ.empty()) { + auto cur = tmpQ.front(); + tmpQ.pop(); + cur->dump(); + } + } +} + bool isSpecialOp(Operation *op) { return llvm::isa(op) || llvm::isa(op) || @@ -37,7 +51,7 @@ bool isSpecialOp(Operation *op) { bool is_innermost_operation(Operation *op) { bool inner_most = true; op->walk([&inner_most](Operation *p) { - if (llvm::isa(p)) { + if (mlir::isa(p)) { inner_most = false; return WalkResult::interrupt(); } @@ -47,7 +61,7 @@ bool is_innermost_operation(Operation *op) { } int generateValidSteps(int steps, VectorType type) { - return type.getShape().back() >= steps ? steps > 16 ? 16 : steps : 1; + return type.getShape().back() >= steps ? (steps > 16 ? 16 : steps) : steps; } // expr equals `vector rank` - 1 @@ -57,14 +71,16 @@ bool isLastDim(const AffineExpr &expr, const size_t rank) { } // Get the maximum number of current data types that a register can hold -[[nodiscard]] int getDataTypeMAXSIMDLength(VectorType type) { +[[nodiscard]] int getDataTypeMAXSIMDLength(const VectorType &type) { auto typebits = type.getElementTypeBitWidth(); const int favx512bits = 512; const int favx2bits = 256; if (HW.favx512f) { - return generateValidSteps(favx512bits / typebits, type); + // return generateValidSteps(favx512bits / typebits, type); + return favx512bits / typebits; } else if (HW.favx2) { - return generateValidSteps(favx2bits / typebits, type); + // return generateValidSteps(favx2bits / typebits, type); + return favx2bits / typebits; } else { // invalid LDBG("Please check the hardware information."); @@ -74,6 +90,9 @@ bool isLastDim(const AffineExpr &expr, const size_t rank) { } mlir::FailureOr getOperationVectorType(Operation *op) { + if (!op) { + return failure(); + } return TypeSwitch>(op) .Case( [&](vector::TransferWriteOp transferWriteOp) @@ -107,7 +126,7 @@ mlir::FailureOr getOperationVectorType(Operation *op) { }); } -VectorType getVectorzedType(Operation *op) { +VectorType getVectorzedType(Operation *op, uint32_t loop_step = 0) { // Check that the operation type can be broken // down into a loop. 
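+  // loop_step == 0 means "use the widest SIMD width for this element type"
+  // (e.g. 512 / 32 = 16 lanes for f32 with AVX-512, 256 / 32 = 8 with AVX2);
+  // otherwise the caller passes the per-group step computed in
+  // analysisEmptyGroupAndMaxSteps.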
auto baseType = getOperationVectorType(op); @@ -117,7 +136,9 @@ VectorType getVectorzedType(Operation *op) { return VectorType(); } auto vectorizedType = baseType.value(); - const int loop_step = getDataTypeMAXSIMDLength(vectorizedType); + if (loop_step == 0) { + loop_step = getDataTypeMAXSIMDLength(vectorizedType); + } return VectorType::get({loop_step}, vectorizedType.getElementType()); } @@ -1264,11 +1285,13 @@ void CanonicalizerCommonUsedData::initSpeicalOperationCanonicalizers() { broadcastCanonicalizers.clear(); multiRdCanonicalizers.clear(); transposeCanonicalizers.clear(); + shapeCastCanonicalizers.clear(); auto &opGroups = fusionStrategy.getOpGroups(); for (auto &grp : opGroups) { addDummyInit(multiRdCanonicalizers); addDummyInit(broadcastCanonicalizers); addDummyInit(transposeCanonicalizers); + addDummyInit(shapeCastCanonicalizers); if (grp.empty()) { continue; @@ -1280,16 +1303,18 @@ void CanonicalizerCommonUsedData::initSpeicalOperationCanonicalizers() { if (mlir::isa(op)) { multiRdCanonicalizers.back().getCandidateOps().emplace_back( mlir::dyn_cast(op)); + multiRdCanonicalizers.back().prepareSpecialOperationInfo(); } else if (mlir::isa(op)) { broadcastCanonicalizers.back().getCandidateOps().emplace_back( mlir::dyn_cast(op)); } else if (mlir::isa(op)) { transposeCanonicalizers.back().getCandidateOps().emplace_back( mlir::dyn_cast(op)); + } else if (mlir::isa(op)) { + shapeCastCanonicalizers.back().getCandidateOps().emplace_back( + mlir::dyn_cast(op)); } } - // todo - multiRdCanonicalizers.back().prepareSpecialOperationInfo(); } } @@ -1782,56 +1807,63 @@ void VectorFusionStrategy::classifyOperations() { opGroups.emplace_back(std::queue()); } func->walk([&](Operation *op) { - TypeSwitch(op).Default([&](Operation *op) { - if (filterOperation(op)) { - addOperationToGroup(opGroups, opGroupIndexMap, op); - // update init operation - } - while (ignoreInitOperations.size() < opGroups.size()) { - ignoreInitOperations.emplace_back(std::queue()); - } - // some init operations need to ignore - if (isInitOperation(op)) { - ignoreInitOperations.back().push(op); - } - }); - }); - for (auto grp : opGroups) { - std::cout << " ____________________" << std::endl; - if (grp.empty()) { - continue; + if (filterOperation(op)) { + addOperationToGroup(opGroups, opGroupIndexMap, op); + // update init operation } - std::queue tmpQ(grp); - while (!tmpQ.empty()) { - auto cur = tmpQ.front(); - tmpQ.pop(); - cur->dump(); + while (ignoreInitOperations.size() < opGroups.size()) { + ignoreInitOperations.emplace_back(std::queue()); } - std::cout << "___________________" << std::endl; - } + // some init operations need to ignore + if (isInitOperation(op)) { + ignoreInitOperations.back().push(op); + } + }); } Value setOutGroupOperationOperandResult(Operation *op, const VectorType &newOperandType) { - auto ret = TypeSwitch(op) - .Case([&](arith::ConstantOp constantOp) { - IRRewriter rewriter(op); - rewriter.setInsertionPointAfter(op); - Type resultElementType = newOperandType.getElementType(); - - Attribute initValueAttr; - if (isa(resultElementType)) { - initValueAttr = FloatAttr::get(resultElementType, 0.0); - - } else { - initValueAttr = IntegerAttr::get(resultElementType, 0); - } - auto cntOp = rewriter.create( - rewriter.getUnknownLoc(), - DenseElementsAttr::get(newOperandType, {initValueAttr})); - return cntOp->getResults()[0]; - }) - .Default([&](Operation *op) { return Value(); }); + auto ret = + TypeSwitch(op) + .Case([&](arith::ConstantOp constantOp) { + IRRewriter rewriter(op); + 
rewriter.setInsertionPointAfter(op); + Type resultElementType = newOperandType.getElementType(); + auto value = constantOp.getValue(); + Attribute initValueAttr; + + if (mlir::isa(value)) { + auto valueType = mlir::dyn_cast(value); + if (valueType.isSplat()) { + if (mlir::isa(valueType.getElementType())) { + initValueAttr = FloatAttr::get( + resultElementType, + valueType.getSplatValue().convertToDouble()); + } else { + initValueAttr = IntegerAttr::get( + resultElementType, + valueType.getSplatValue().getSExtValue()); + } + } else { + // write original vector into tensor + // then we transfer_read from the tensor + assert(0 && "Not support non-splat constant value."); + } + } else if (isa(resultElementType)) { + initValueAttr = FloatAttr::get( + resultElementType, + llvm::cast(value).getValueAsDouble()); + } else { + initValueAttr = IntegerAttr::get( + resultElementType, llvm::cast(value).getInt()); + } + + auto cntOp = rewriter.create( + rewriter.getUnknownLoc(), + DenseElementsAttr::get(newOperandType, {initValueAttr})); + return cntOp->getResults()[0]; + }) + .Default([&](Operation *op) { return Value(); }); return ret; } @@ -1856,14 +1888,36 @@ void setOperationOperandResult( } }; -void createNewConstantOp( - Operation *srcOp, vector::TransferWriteOp *transferWriteOp, - llvm::DenseMap &opPermuationMap) { +void CanonicalizerVectorOperation::createNewConstantOp( + Operation *srcOp, vector::TransferWriteOp *transferWriteOp) { + auto &opPermuationMap = commonUsedData.getOpPermuationMap(); IRRewriter srcWriter(srcOp); auto newOperandType = getVectorzedType(mlir::cast(srcOp)); auto srcConstantOp = mlir::dyn_cast(srcOp); - Operation *newConstantOp = srcWriter.create( - srcOp->getLoc(), srcConstantOp.getValueAttr()); + Operation *newConstantOp; + if (mlir::isa(srcConstantOp.getValue())) { + auto valueType = mlir::dyn_cast(srcConstantOp.getValue()); + if (valueType.isSplat()) { + if (mlir::isa(valueType.getElementType())) { + newConstantOp = srcWriter.create( + srcOp->getLoc(), + FloatAttr::get(newOperandType, valueType.getSplatValue())); + } else { + newConstantOp = srcWriter.create( + srcOp->getLoc(), + IntegerAttr::get(newOperandType, valueType.getSplatValue())); + } + + } else { + // write original vector into tensor + // then we transfer_read from the tensor + assert(0 && "Not support non-splat constant value."); + } + } else { + newConstantOp = srcWriter.create( + srcOp->getLoc(), srcConstantOp.getValue()); + } + newConstantOp->getResult(0).setType(newOperandType); transferWriteOp->setOperand(0, newConstantOp->getResult(0)); opPermuationMap.insert( @@ -1884,6 +1938,8 @@ void CanonicalizerVectorOperation::rewriteOperationAsVectorize( auto &opMap = commonUsedData.getFusionStrategy().getOpGroupIndexMap(); auto &opPermuationMap = commonUsedData.getOpPermuationMap(); std::queue transformQueue(groupOps); + auto groupSteps = + commonUsedData.getFusionStrategy().getGroupMaxSteps()[groupId]; while (!transformQueue.empty()) { auto op = transformQueue.front(); @@ -1893,14 +1949,16 @@ void CanonicalizerVectorOperation::rewriteOperationAsVectorize( .Case( [&](vector::TransferWriteOp transferWriteOp) { IRRewriter rewriter(transferWriteOp); - auto newOperandType = getVectorzedType(transferWriteOp); + auto newOperandType = + getVectorzedType(transferWriteOp, groupSteps); auto srcOp = transferWriteOp->getOperand(0).getDefiningOp(); if (mlir::isa(srcOp)) { - createNewConstantOp(srcOp, &transferWriteOp, - opPermuationMap); + createNewConstantOp(srcOp, &transferWriteOp); } else { + transferWriteOp->dump(); 
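+              // The permutation map recorded below is looked up again in
+              // checkAndSetOperand once the op has been moved into the
+              // generated scf.for nest, so that the loop induction variables
+              // can be substituted into the transfer indices.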
opPermuationMap.insert( {transferWriteOp, transferWriteOp.getPermutationMap()}); + newOperandType.dump(); transferWriteOp->getOperand(0).setType(newOperandType); setOpVectorizationPermutationMap( @@ -1914,7 +1972,8 @@ void CanonicalizerVectorOperation::rewriteOperationAsVectorize( }) .Case( [&](vector::TransferReadOp transferReadOp) { - auto newOperandType = getVectorzedType(transferReadOp); + auto newOperandType = + getVectorzedType(transferReadOp, groupSteps); opPermuationMap.insert( {transferReadOp, transferReadOp.getPermutationMap()}); transferReadOp->getResult(0).setType(newOperandType); @@ -1933,7 +1992,7 @@ void CanonicalizerVectorOperation::rewriteOperationAsVectorize( return failure(); }) .Case([&](arith::ExtFOp extFop) { - auto newOperandType = getVectorzedType(extFop); + auto newOperandType = getVectorzedType(extFop, groupSteps); extFop->getResult(0).setType(newOperandType); return success(); }) @@ -1943,7 +2002,8 @@ void CanonicalizerVectorOperation::rewriteOperationAsVectorize( "It should not appear this operation."); return failure(); } - setOperationOperandResult(op, getVectorzedType(op), opMap); + setOperationOperandResult(op, getVectorzedType(op, groupSteps), + opMap); return success(); }); if (failed(lowerResult)) { @@ -2019,10 +2079,45 @@ void CanonicalizerVectorOperation::generateEmptyTensorAndWrite( groupOpResults[sourceOpGid].insert(result); } +void CanonicalizerVectorOperation::analysisEmptyGroupAndMaxSteps() { + auto &groupOpResults = commonUsedData.getGroupOpResults(); + auto &opGroups = commonUsedData.getFusionStrategy().getOpGroups(); + + // If the group operations do not have result need to be returned, these are + // useless code. + for (auto [idx, grp] : enumerate(opGroups)) { + if (groupOpResults[idx].empty()) { + std::queue().swap(grp); + } + uint32_t steps = std::numeric_limits::max(); + + auto &grpSteps = commonUsedData.getFusionStrategy().getGroupMaxSteps(); + while (idx >= grpSteps.size()) { + grpSteps.emplace_back(steps); + } + std::queue tmpQueue(grp); + auto calculateOpSteps = [&](Type type) { + auto opType = mlir::dyn_cast(type); + if (opType) + steps = std::min(steps, (uint32_t)getDataTypeMAXSIMDLength(opType)); + }; + while (!tmpQueue.empty()) { + auto op = tmpQueue.front(); + tmpQueue.pop(); + if (mlir::isa(op)) { + calculateOpSteps(op->getOperandTypes()[0]); + } + calculateOpSteps(op->getResultTypes()[0]); + } + grpSteps[idx] = steps; + } +} + // analysis operation result of current group whether needed by other // operation which out of current group void CanonicalizerVectorOperation::analysisGroupOperationResults() { llvm::DenseMap> srcOpCanoniclizedMap; + llvm::DenseSet movedOperationSet; auto &commonUsedData = getCommonUsedData(); auto &opGroups = commonUsedData.getFusionStrategy().getOpGroups(); auto &opGroupIndexMap = @@ -2075,29 +2170,17 @@ void CanonicalizerVectorOperation::analysisGroupOperationResults() { } } } - // reduce have reduction axis and parallel axis, reduction loop must write - // result back - // if (mlir::isa(op)) { - // if (!srcOpCanoniclizedMap.contains(op)) { - // generateEmptyTensorAndWrite(op, srcOpCanoniclizedMap); - // } - // } - }); - // If the group operations do not have result need to be returned, these are - // useless code. 
- for (auto [idx, grp] : enumerate(opGroups)) { - if (groupOpResults[idx].empty()) { - std::queue().swap(grp); + if (mlir::isa(op) && !movedOperationSet.contains(op)) { + auto parentBlock = op->getBlock(); + op->moveBefore(parentBlock, parentBlock->getOperations().begin()); + movedOperationSet.insert(op); } - } + }); + analysisEmptyGroupAndMaxSteps(); LDBG("Complete analysis group operation results\n"); } void CanonicalizerVectorOperation::analysisGroupOperaionOperandsResults() { - - // Operands - // analysisGroupOperationOperands(opGroups, opGroupIndexMap); - // Results analysisGroupOperationResults(); } @@ -2150,9 +2233,12 @@ bool CanonicalizerVectorOperation::isGroupHasSpecialOperation( auto &rdCanonicalizer = commonUsedData.getMultiRdCanonicalizer()[grpIdx]; auto &bcCanonicalizer = commonUsedData.getBroadcastCanonicalizer()[grpIdx]; auto &tpCanonicalizer = commonUsedData.getTransposeCanonicalizer()[grpIdx]; + auto &shapeCastCanonicalizer = + commonUsedData.getShapeCastCanonicalizer()[grpIdx]; return !rdCanonicalizer.getCandidateOps().empty() or !bcCanonicalizer.getCandidateOps().empty() or - !tpCanonicalizer.getCandidateOps().empty(); + !tpCanonicalizer.getCandidateOps().empty() or + !shapeCastCanonicalizer.getCandidateOps().empty(); } void CanonicalizerVectorOperation::generateGroupOpVectorizedIR(const int idx) { @@ -2177,8 +2263,8 @@ void CanonicalizerVectorOperation::generateGroupOpVectorizedIR(const int idx) { IRRewriter rewriter(grp.back()); rewriter.setInsertionPointAfter(grp.back()); // 1. Rewrite operation as vectorized form - rewriteOperationAsVectorize(rewriter, idx); // 2. Generate loop + rewriteOperationAsVectorize(rewriter, idx); auto forOp = generateVectorizedForLoop(rewriter, groupOpResults[idx], groupOpIterArgs[idx], opShapes, grp, opPermuationMap); diff --git a/lib/gc/Transforms/LowerTileVectorPass.cpp b/lib/gc/Transforms/LowerTileVectorPass.cpp index c55674542..749facc6b 100644 --- a/lib/gc/Transforms/LowerTileVectorPass.cpp +++ b/lib/gc/Transforms/LowerTileVectorPass.cpp @@ -77,14 +77,24 @@ LogicalResult lowerBitcastOpPrecondition(tensor::BitcastOp bitCastOp) { /// Need to check if the reassociation are static/constant. LogicalResult -lowerCollapseShapeOpPrecondition(tensor::CollapseShapeOp expandOp) { - - if (llvm::any_of(expandOp.getReassociation(), [](Attribute x) { - return !getConstantIntValue(x).has_value(); - })) { - LDBG("Reassociation must be constant: " << expandOp << "\n"); +lowerCollapseShapeOpPrecondition(tensor::CollapseShapeOp collapseOp) { + auto isShapeStatic = [](Value v) { + auto type = mlir::dyn_cast(v.getType()); + if (!type) { + LDBG("Operation type error: " << v << "\n"); + return false; + } + return type.hasStaticShape(); + }; + if (!isShapeStatic(collapseOp->getResults()[0])) { + LDBG("Output shape must be static: " << collapseOp << "\n"); + return failure(); + } + if (!isShapeStatic(collapseOp.getSrc())) { + LDBG("Input shape must be static: " << collapseOp << "\n"); return failure(); } + return success(); } @@ -120,104 +130,130 @@ LogicalResult lowerTargetOpPrecondition(Operation *op) { .Default([](auto) { return failure(); }); } -/// Create a TransferReadOp from `source` with static shape `readShape`. 
-Value createTransferRead(OpBuilder &builder, Location loc, Value source, - ArrayRef readShape) { - assert(llvm::none_of(readShape, - [](int64_t s) { return s == ShapedType::kDynamic; })); - assert(source && " source null."); - auto shapedType = mlir::dyn_cast(source.getType()); - auto sourceShape = shapedType.getShape(); - auto vectorType = VectorType::get(readShape, shapedType.getElementType()); - - auto padValue = builder.create( - loc, builder.getZeroAttr(shapedType.getElementType())); - assert(sourceShape.size() == readShape.size()); - int64_t readRank = readShape.size(); - auto zero = builder.create(loc, 0); - SmallVector inBoundsVal(readRank, true); - auto transferReadOp = builder.create( - loc, - /*vectorType=*/vectorType, - /*source=*/source, - /*indices=*/SmallVector(readRank, zero), - /*padding=*/padValue, - /*inBounds=*/inBoundsVal); +Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc, + Value input, + SmallVector destSizes, + ArrayRef inputVectorSizes, + bool useInBoundsInsteadOfMasking) { - if (llvm::equal(readShape, sourceShape)) { - return transferReadOp; - } else { - assert(false && "wrong shape."); - } -} - -/// create an empty destination tensor and create a TransferWriteOp from the -/// input to the empty tensor. -Operation *createTransferWrite(OpBuilder &builder, Location loc, Value input, - SmallVector destSizes, - ArrayRef inputVectorSizes) { auto inputType = cast(input.getType()); Value dest = builder.create(loc, destSizes, inputType.getElementType()); int64_t rank = cast(dest.getType()).getRank(); auto zero = builder.create(loc, 0); + auto destShape = cast(dest.getType()).getShape(); + SmallVector inBoundsVal(rank, true); + if (useInBoundsInsteadOfMasking) { + // Update the inBounds attribute. + for (unsigned i = 0; i < rank; i++) + inBoundsVal[i] = (destShape[i] == inputVectorSizes[i]) && + !ShapedType::isDynamic(destShape[i]); + } Operation *write = builder.create( loc, /*vector=*/input, /*source=*/dest, /*indices=*/SmallVector(rank, zero), - /*inBounds=*/SmallVector(rank, true)); - auto destShape = cast(dest.getType()).getShape(); + /*inBounds=*/inBoundsVal); assert(llvm::none_of( destShape.drop_front(inputVectorSizes.size()), [](int64_t size) { return size == ShapedType::kDynamic; }) && - "InputVectorSizes may be dynamic"); + "Only dims aligned with inputVectorSizes may be dynamic"); + if (useInBoundsInsteadOfMasking) + return write; + bool needMaskForWrite = !llvm::equal( + inputVectorSizes, destShape.take_front(inputVectorSizes.size())); + if (needMaskForWrite) { + SmallVector writeMaskShape; + writeMaskShape.append(inputVectorSizes.begin(), inputVectorSizes.end()); + writeMaskShape.append(destShape.begin() + inputVectorSizes.size(), + destShape.end()); + auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type()); + Value maskForWrite = + builder.create(loc, writeMaskType, destSizes); + write = mlir::vector::maskOperation(builder, write, maskForWrite); + } return write; } +Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source, + ArrayRef readShape, Value padValue, + bool useInBoundsInsteadOfMasking) { + assert(llvm::none_of(readShape, + [](int64_t s) { return s == ShapedType::kDynamic; }) && + "expected static shape"); + auto sourceShapedType = cast(source.getType()); + auto sourceShape = sourceShapedType.getShape(); + assert(sourceShape.size() == readShape.size() && "expected same ranks."); + auto maskType = VectorType::get(readShape, builder.getI1Type()); + auto vectorType = VectorType::get(readShape, 
padValue.getType()); + assert(padValue.getType() == sourceShapedType.getElementType() && + "expected same pad element type to match source element type"); + int64_t readRank = readShape.size(); + auto zero = builder.create(loc, 0); + SmallVector inBoundsVal(readRank, true); + if (useInBoundsInsteadOfMasking) { + // Update the inBounds attribute. + for (unsigned i = 0; i < readRank; i++) + inBoundsVal[i] = (sourceShape[i] == readShape[i]) && + !ShapedType::isDynamic(sourceShape[i]); + } + auto transferReadOp = builder.create( + loc, + /*vectorType=*/vectorType, + /*source=*/source, + /*indices=*/SmallVector(readRank, zero), + /*padding=*/padValue, + /*inBounds=*/inBoundsVal); + + if (llvm::equal(readShape, sourceShape) || useInBoundsInsteadOfMasking) + return transferReadOp; + SmallVector mixedSourceDims = + tensor::getMixedSizes(builder, loc, source); + Value mask = + builder.create(loc, maskType, mixedSourceDims); + return mlir::vector::maskOperation(builder, transferReadOp, mask) + ->getResult(0); +} + /// Vectorize a `tensor::expandshape` to these 3 Ops: /// Vector::TransferReadOp - Reads a vector from the source tensor /// ShapeCastOp - Reshape the data based on the target. /// vector::TransferWriteOp. - Write the result vector back to the destination /// tensor template -LogicalResult lowerTensorExpandShapeOp(RewriterBase &rewriter, T expandShapeOp, +LogicalResult lowerTensorExpandShapeOp(RewriterBase &rewriter, + Operation *inputOp, SmallVectorImpl &newResults) { OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPoint(expandShapeOp); - - RankedTensorType expandShapeTensorType = expandShapeOp.getSrcType(); - - SmallVector readMaskShape; - ArrayRef sourceShape = expandShapeTensorType.getShape(); - ArrayRef resultShape = expandShapeOp.getResultType().getShape(); - readMaskShape.append(sourceShape.begin(), sourceShape.end()); - - ReifiedRankedShapedTypeDims reifiedRetShapes; - LogicalResult status = - cast(expandShapeOp.getOperation()) - .reifyResultShapes(rewriter, reifiedRetShapes); - if (status.failed()) { - LDBG("Unable to reify result shapes of " << expandShapeOp << "\n"); - return failure(); - } - Location loc = expandShapeOp->getLoc(); - - // Read result, mask if necessary. If transferReadOp shape is not equal - // to shape of source, then a mask is necessary. 
- Value readResult = createTransferRead( - rewriter, loc, expandShapeOp.getSrc(), - ArrayRef(readMaskShape.begin(), readMaskShape.end())); - - auto resultVectorType = - VectorType::get(resultShape, expandShapeTensorType.getElementType()); + rewriter.setInsertionPoint(inputOp); + auto src = inputOp->getOperand(0); + auto srcType = mlir::dyn_cast(src.getType()); + auto result = inputOp->getResults()[0]; + auto resultType = mlir::dyn_cast(result.getType()); + + ArrayRef resultShape = resultType.getShape(); + Location loc = inputOp->getLoc(); + + // read + auto padValue = rewriter.create( + loc, rewriter.getZeroAttr(srcType.getElementType())); + Value readResult = createReadOrMaskedRead( + rewriter, loc, src, srcType.getShape(), padValue, false); + + auto shapeCastType = + VectorType::get(resultType.getShape(), resultType.getElementType()); vector::ShapeCastOp shapeCastOp = - rewriter.create(loc, resultVectorType, readResult); + rewriter.create(loc, shapeCastType, readResult); - SmallVector writeMaskShape( - shapeCastOp.getResultVectorType().getShape()); - Operation *write = createTransferWrite(rewriter, loc, shapeCastOp.getResult(), - reifiedRetShapes[0], writeMaskShape); + // write + SmallVector destSizes; + for (auto size : resultShape) { + destSizes.emplace_back(rewriter.getIndexAttr(size)); + } + Operation *write = + createWriteOrMaskedWrite(rewriter, loc, shapeCastOp->getResults()[0], + destSizes, resultShape, false); newResults.push_back(write->getResult(0)); return success(); } @@ -234,19 +270,14 @@ LogicalResult lowerTensorBitcastOp(RewriterBase &rewriter, rewriter.setInsertionPoint(bitCastOp); auto sourceType = bitCastOp.getSource().getType(); - auto sourceShape = sourceType.getShape(); auto resultType = bitCastOp.getResult().getType(); auto resultShape = resultType.getShape(); - - SmallVector readMaskShape; - readMaskShape.append(sourceShape.begin(), sourceShape.end()); Location loc = bitCastOp->getLoc(); - // Read result, mask if necessary. If transferReadOp shape is not equal - // to shape of source, then a mask is necessary. - Value readResult = createTransferRead( - rewriter, loc, bitCastOp->getOperand(0), - ArrayRef(readMaskShape.begin(), readMaskShape.end())); + auto padValue = rewriter.create( + loc, rewriter.getZeroAttr(sourceType.getElementType())); + Value readResult = createReadOrMaskedRead( + rewriter, loc, bitCastOp.getSource(), resultShape, padValue, false); auto resultVectorType = VectorType::get(resultShape, resultType.getElementType()); @@ -259,8 +290,8 @@ LogicalResult lowerTensorBitcastOp(RewriterBase &rewriter, for (auto size : resultShape) destSizes.emplace_back(rewriter.getIndexAttr(size)); auto write = - createTransferWrite(rewriter, loc, vectorbitCastOp->getResults()[0], - destSizes, writeMaskShape); + createWriteOrMaskedWrite(rewriter, loc, vectorbitCastOp->getResult(0), + destSizes, resultShape, false); newResults.push_back(write->getResults()[0]); return success(); } @@ -293,6 +324,10 @@ LogicalResult lowerTensorConcatOp(RewriterBase &rewriter, rewriter.create(loc, rewriter.getIndexAttr(dim)); int64_t rank = concatOp.getResultType().getRank(); + auto srcType = + mlir::dyn_cast(concatOp->getResultTypes()[0]); + auto padValue = rewriter.create( + loc, rewriter.getZeroAttr(srcType.getElementType())); // Construct the chain of insert_slice ops into the destination. 
Value result = *dest; @@ -302,13 +337,12 @@ LogicalResult lowerTensorConcatOp(RewriterBase &rewriter, SmallVector sizes = tensor::getMixedSizes(rewriter, loc, input); SmallVector readMaskShape; - auto inputType = llvm::cast(input.getType()); + auto inputType = mlir::dyn_cast(input.getType()); auto sourceShape = inputType.getShape(); readMaskShape.append(sourceShape.begin(), sourceShape.end()); - Value readResult = createTransferRead( - rewriter, loc, input, - ArrayRef(readMaskShape.begin(), readMaskShape.end())); + Value readResult = createReadOrMaskedRead(rewriter, loc, input, sourceShape, + padValue, false); Value zero = rewriter.create(loc, 0); SmallVector indices(rank, zero); indices[dim] = previous_offset; From 947e5c0fb47fe8399ddd6f5c0e22eefd5551321b Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Fri, 5 Jul 2024 16:30:15 +0800 Subject: [PATCH 12/66] temp record, please reset back --- lib/gc/Transforms/CPUPhysicalResigterPass.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/lib/gc/Transforms/CPUPhysicalResigterPass.cpp b/lib/gc/Transforms/CPUPhysicalResigterPass.cpp index f98758c07..2795951c9 100644 --- a/lib/gc/Transforms/CPUPhysicalResigterPass.cpp +++ b/lib/gc/Transforms/CPUPhysicalResigterPass.cpp @@ -1439,15 +1439,12 @@ void checkAndSetOperand( if (llvm::dyn_cast(op) || llvm::dyn_cast(op)) { assert(opPermuationMap.contains(op)); - std::cout << "op verify ..." << std::endl; - op->dump(); auto permutationMap = opPermuationMap.at(op); auto dimExpr = permutationMap.getResults(); for (auto [idx, x] : llvm::enumerate(dimExpr)) { if (mlir::dyn_cast(x)) { auto dim = mlir::dyn_cast(x).getPosition(); - std::cout << inductionVars.size() << "," << dim << std::endl; op->setOperand(dim + offset, inductionVars[dim]); } } From 1101e86bce925a2d1654da4064c2f1a2594de370 Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Fri, 5 Jul 2024 16:30:56 +0800 Subject: [PATCH 13/66] temp record, please reset back --- .../gc/transforms/cpu-vetor-distribution.mlir | 345 +++++++++--------- 1 file changed, 165 insertions(+), 180 deletions(-) diff --git a/test/gc/transforms/cpu-vetor-distribution.mlir b/test/gc/transforms/cpu-vetor-distribution.mlir index 32ee8414f..64534ebba 100644 --- a/test/gc/transforms/cpu-vetor-distribution.mlir +++ b/test/gc/transforms/cpu-vetor-distribution.mlir @@ -8,189 +8,174 @@ func.func @add_tensor_test0(%arg0: tensor<11008x4096xf32>, %arg1: tensor<11008x4 return %2 : tensor<11008x4096xf32> } -// func.func @fc_relu(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>, -// %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>) -// -> tensor<512x512xf32> { -// // Matrix-matrix multiplication. -// %matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>) -// outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32> +func.func @fc_relu(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>, + %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>) + -> tensor<512x512xf32> { + // Matrix-matrix multiplication. + %matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>) + outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32> -// // Elementwise addition. -// %biased = linalg.elemwise_binary { fun = #linalg.binary_fn } -// ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>) -// outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> + // Elementwise addition. 
+ %biased = linalg.elemwise_binary { fun = #linalg.binary_fn } + ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>) + outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> -// // Elementwise max with 0 (ReLU). -// %c0f = arith.constant 0.0 : f32 -// // expected-remark @below {{elementwise binary}} -// %relued = linalg.elemwise_binary { fun = #linalg.binary_fn } -// ins(%biased, %c0f : tensor<512x512xf32>, f32) -// outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> -// func.return %relued : tensor<512x512xf32> -// } - -// func.func @reduce_keepdim0(%arg0: tensor<16x32x64xf32>) -> tensor<16x1x64xf32> { -// %0 = tensor.empty() : tensor<16x64xf32> -// %reduce = linalg.reduce -// ins(%arg0:tensor<16x32x64xf32>) -// outs(%0:tensor<16x64xf32>) -// dimensions = [1] -// (%in: f32, %out: f32) { -// %1 = arith.addf %out, %in: f32 -// linalg.yield %1: f32 -// } -// %2 = tensor.expand_shape %reduce [[0],[1, 2]] : tensor<16x64xf32> into tensor<16x1x64xf32> -// return %2 : tensor<16x1x64xf32> -// } + // Elementwise max with 0 (ReLU). + %c0f = arith.constant 0.0 : f32 + // expected-remark @below {{elementwise binary}} + %relued = linalg.elemwise_binary { fun = #linalg.binary_fn } + ins(%biased, %c0f : tensor<512x512xf32>, f32) + outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> + func.return %relued : tensor<512x512xf32> +} -// func.func @insert_pad_into_fill(%input: tensor, %low0: index, %low1: index, %high1: index, %high2: index) -> tensor<8x384x384xf32> { -// %f0 = arith.constant 0.0 : f32 -// %c0 = arith.constant 0 : index -// %pad = tensor.pad %input low[%low0, %low1, %c0] high[%c0, %high1, %high2] { -// ^bb0(%arg3: index, %arg4: index, %arg5: index): -// tensor.yield %f0 : f32 -// } : tensor to tensor<8x128x128xf32> -// %empty = tensor.empty() : tensor<8x384x384xf32> -// %fill = linalg.fill ins(%f0 : f32) outs(%empty : tensor<8x384x384xf32>) -> tensor<8x384x384xf32> -// %0 = tensor.insert_slice %pad into %fill[0, 1, 2] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32> -// return %0: tensor<8x384x384xf32> -// } +func.func @reduce_keepdim0(%arg0: tensor<16x32x64xf32>) -> tensor<16x1x64xf32> { + %0 = tensor.empty() : tensor<16x64xf32> + %reduce = linalg.reduce + ins(%arg0:tensor<16x32x64xf32>) + outs(%0:tensor<16x64xf32>) + dimensions = [1] + (%in: f32, %out: f32) { + %1 = arith.addf %out, %in: f32 + linalg.yield %1: f32 + } + %2 = tensor.expand_shape %reduce [[0],[1, 2]] output_shape [16, 1, 64] : tensor<16x64xf32> into tensor<16x1x64xf32> + return %2 : tensor<16x1x64xf32> +} -// #map = affine_map<(d0) -> (d0 * 64)> -// #map1 = affine_map<(d0) -> (d0 * 128)> -// #map2 = affine_map<(d0) -> (d0 floordiv 16)> -// #map3 = affine_map<(d0) -> (d0 floordiv 32)> -// #map4 = affine_map<(d0, d1, d2) -> (d0 + d1 + d2 * 128)> -// #map5 = affine_map<(d0, d1, d2) -> (d0 + d1 + d2 * 64)> -// module { -// func.func @mlp(%arg0: tensor<128x512xbf16>, %arg1: tensor<32x8x16x32xbf16>, %arg2: tensor<256xbf16>) -> tensor<128x256xbf16> { -// %c32 = arith.constant 32 : index -// %c512 = arith.constant 512 : index -// %c128 = arith.constant 128 : index -// %c64 = arith.constant 64 : index -// %c0 = arith.constant 0 : index -// %cst = arith.constant 0.000000e+00 : bf16 -// %0 = tensor.empty() : tensor<128x256xbf16> -// %1 = tensor.empty() : tensor<512x256xbf16> -// %2:3 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %0, %arg6 = %0, %arg7 = %0) -> (tensor<128x256xbf16>, tensor<128x256xbf16>, tensor<128x256xbf16>) { -// %3 = affine.apply #map(%arg3) -// %4 = 
affine.apply #map1(%arg4) -// %extracted_slice = tensor.extract_slice %arg0[%3, 0] [64, 512] [1, 1] : tensor<128x512xbf16> to tensor<64x512xbf16> -// %extracted_slice_0 = tensor.extract_slice %arg5[%3, %4] [64, 128] [1, 1] : tensor<128x256xbf16> to tensor<64x128xbf16> -// %5:3 = scf.for %arg8 = %c0 to %c64 step %c64 iter_args(%arg9 = %extracted_slice_0, %arg10 = %extracted_slice_0, %arg11 = %extracted_slice_0) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { -// %6:3 = scf.for %arg12 = %c0 to %c128 step %c128 iter_args(%arg13 = %arg9, %arg14 = %arg10, %arg15 = %arg11) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { -// %7:3 = scf.for %arg16 = %c0 to %c512 step %c512 iter_args(%arg17 = %arg13, %arg18 = %arg14, %arg19 = %arg15) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { -// %extracted_slice_1 = tensor.extract_slice %extracted_slice[%arg8, %arg16] [64, 512] [1, 1] : tensor<64x512xbf16> to tensor<64x512xbf16> -// %extracted_slice_2 = tensor.extract_slice %arg17[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> to tensor<64x128xbf16> -// %8:3 = scf.for %arg20 = %c0 to %c64 step %c32 iter_args(%arg21 = %extracted_slice_2, %arg22 = %arg18, %arg23 = %arg19) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { -// %9:3 = scf.for %arg24 = %c0 to %c128 step %c32 iter_args(%arg25 = %arg21, %arg26 = %arg22, %arg27 = %arg23) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { -// %10:3 = scf.for %arg28 = %c0 to %c512 step %c512 iter_args(%arg29 = %arg25, %arg30 = %arg26, %arg31 = %arg27) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { -// %extracted_slice_5 = tensor.extract_slice %extracted_slice_1[%arg20, %arg28] [32, 512] [1, 1] : tensor<64x512xbf16> to tensor<32x512xbf16> -// %11 = affine.apply #map2(%arg28) -// %12 = affine.apply #map3(%arg24) -// %extracted_slice_6 = tensor.extract_slice %arg1[%11, %12, 0, 0] [32, 1, 16, 32] [1, 1, 1, 1] : tensor<32x8x16x32xbf16> to tensor<32x1x16x32xbf16> -// %extracted_slice_7 = tensor.extract_slice %1[%arg28, %arg24] [512, 32] [1, 1] : tensor<512x256xbf16> to tensor<512x32xbf16> -// %unpack = tensor.unpack %extracted_slice_6 inner_dims_pos = [0, 1] inner_tiles = [16, 32] into %extracted_slice_7 : tensor<32x1x16x32xbf16> -> tensor<512x32xbf16> -// %extracted_slice_8 = tensor.extract_slice %arg29[%arg20, %arg24] [32, 32] [1, 1] : tensor<64x128xbf16> to tensor<32x32xbf16> -// %13 = linalg.fill ins(%cst : bf16) outs(%extracted_slice_8 : tensor<32x32xbf16>) -> tensor<32x32xbf16> -// %expanded = tensor.expand_shape %extracted_slice_5 [[0, 1], [2]] output_shape [1, 32, 512] : tensor<32x512xbf16> into tensor<1x32x512xbf16> -// %expanded_9 = tensor.expand_shape %unpack [[0, 1], [2]] output_shape [1, 32, 512] : tensor<512x32xbf16> into tensor<1x512x32xbf16> -// %14 = linalg.batch_reduce_matmul ins(%expanded, %expanded_9 : tensor<1x32x512xbf16>, tensor<1x512x32xbf16>) outs(%13 : tensor<32x32xbf16>) -> tensor<32x32xbf16> -// %15 = affine.apply #map4(%arg12, %arg24, %arg4) -// %16 = affine.apply #map5(%arg8, %arg20, %arg3) -// %extracted_slice_10 = tensor.extract_slice %arg2[%15] [32] [1] : tensor<256xbf16> to tensor<32xbf16> -// %extracted_slice_11 = tensor.extract_slice %0[%16, %15] [32, 32] [1, 1] : tensor<128x256xbf16> to tensor<32x32xbf16> -// %broadcasted = linalg.broadcast ins(%extracted_slice_10 : tensor<32xbf16>) outs(%extracted_slice_11 : tensor<32x32xbf16>) dimensions = [0] -// %extracted_slice_12 = tensor.extract_slice %arg30[%arg20, 
%arg24] [32, 32] [1, 1] : tensor<64x128xbf16> to tensor<32x32xbf16> -// %17 = linalg.add ins(%14, %broadcasted : tensor<32x32xbf16>, tensor<32x32xbf16>) outs(%extracted_slice_12 : tensor<32x32xbf16>) -> tensor<32x32xbf16> -// %inserted_slice_13 = tensor.insert_slice %14 into %arg29[%arg20, %arg24] [32, 32] [1, 1] : tensor<32x32xbf16> into tensor<64x128xbf16> -// %extracted_slice_14 = tensor.extract_slice %arg31[%arg20, %arg24] [32, 32] [1, 1] : tensor<64x128xbf16> to tensor<32x32xbf16> -// %18 = linalg.exp ins(%17 : tensor<32x32xbf16>) outs(%extracted_slice_14 : tensor<32x32xbf16>) -> tensor<32x32xbf16> -// %inserted_slice_15 = tensor.insert_slice %17 into %arg30[%arg20, %arg24] [32, 32] [1, 1] : tensor<32x32xbf16> into tensor<64x128xbf16> -// %inserted_slice_16 = tensor.insert_slice %18 into %arg31[%arg20, %arg24] [32, 32] [1, 1] : tensor<32x32xbf16> into tensor<64x128xbf16> -// scf.yield %inserted_slice_13, %inserted_slice_15, %inserted_slice_16 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> -// } -// scf.yield %10#0, %10#1, %10#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> -// } -// scf.yield %9#0, %9#1, %9#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> -// } -// %inserted_slice = tensor.insert_slice %8#0 into %arg17[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<64x128xbf16> -// %inserted_slice_3 = tensor.insert_slice %8#1 into %arg18[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<64x128xbf16> -// %inserted_slice_4 = tensor.insert_slice %8#2 into %arg19[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<64x128xbf16> -// scf.yield %inserted_slice, %inserted_slice_3, %inserted_slice_4 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> -// } -// scf.yield %7#0, %7#1, %7#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> -// } -// scf.yield %6#0, %6#1, %6#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> -// } -// scf.forall.in_parallel { -// tensor.parallel_insert_slice %5#2 into %arg7[%3, %4] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<128x256xbf16> -// tensor.parallel_insert_slice %5#1 into %arg6[%3, %4] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<128x256xbf16> -// tensor.parallel_insert_slice %5#0 into %arg5[%3, %4] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<128x256xbf16> -// } -// } -// return %2#2 : tensor<128x256xbf16> -// } -// } +#map = affine_map<(d0) -> (d0 * 64)> +#map1 = affine_map<(d0) -> (d0 * 128)> +#map2 = affine_map<(d0) -> (d0 floordiv 16)> +#map3 = affine_map<(d0) -> (d0 floordiv 32)> +#map4 = affine_map<(d0, d1, d2) -> (d0 + d1 + d2 * 128)> +#map5 = affine_map<(d0, d1, d2) -> (d0 + d1 + d2 * 64)> + func.func @mlp(%arg0: tensor<128x512xbf16>, %arg1: tensor<32x8x16x32xbf16>, %arg2: tensor<256xbf16>) -> tensor<128x256xbf16> { + %c32 = arith.constant 32 : index + %c512 = arith.constant 512 : index + %c128 = arith.constant 128 : index + %c64 = arith.constant 64 : index + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : bf16 + %0 = tensor.empty() : tensor<128x256xbf16> + %1 = tensor.empty() : tensor<512x256xbf16> + %2:3 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %0, %arg6 = %0, %arg7 = %0) -> (tensor<128x256xbf16>, tensor<128x256xbf16>, tensor<128x256xbf16>) { + %3 = affine.apply #map(%arg3) + %4 = affine.apply #map1(%arg4) + %extracted_slice = tensor.extract_slice %arg0[%3, 0] [64, 512] [1, 1] : tensor<128x512xbf16> to tensor<64x512xbf16> + %extracted_slice_0 = 
tensor.extract_slice %arg5[%3, %4] [64, 128] [1, 1] : tensor<128x256xbf16> to tensor<64x128xbf16> + %5:3 = scf.for %arg8 = %c0 to %c64 step %c64 iter_args(%arg9 = %extracted_slice_0, %arg10 = %extracted_slice_0, %arg11 = %extracted_slice_0) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { + %6:3 = scf.for %arg12 = %c0 to %c128 step %c128 iter_args(%arg13 = %arg9, %arg14 = %arg10, %arg15 = %arg11) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { + %7:3 = scf.for %arg16 = %c0 to %c512 step %c512 iter_args(%arg17 = %arg13, %arg18 = %arg14, %arg19 = %arg15) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { + %extracted_slice_1 = tensor.extract_slice %extracted_slice[%arg8, %arg16] [64, 512] [1, 1] : tensor<64x512xbf16> to tensor<64x512xbf16> + %extracted_slice_2 = tensor.extract_slice %arg17[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> to tensor<64x128xbf16> + %8:3 = scf.for %arg20 = %c0 to %c64 step %c32 iter_args(%arg21 = %extracted_slice_2, %arg22 = %arg18, %arg23 = %arg19) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { + %9:3 = scf.for %arg24 = %c0 to %c128 step %c32 iter_args(%arg25 = %arg21, %arg26 = %arg22, %arg27 = %arg23) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { + %10:3 = scf.for %arg28 = %c0 to %c512 step %c512 iter_args(%arg29 = %arg25, %arg30 = %arg26, %arg31 = %arg27) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { + %extracted_slice_5 = tensor.extract_slice %extracted_slice_1[%arg20, %arg28] [32, 512] [1, 1] : tensor<64x512xbf16> to tensor<32x512xbf16> + %11 = affine.apply #map2(%arg28) + %12 = affine.apply #map3(%arg24) + %extracted_slice_6 = tensor.extract_slice %arg1[%11, %12, 0, 0] [32, 1, 16, 32] [1, 1, 1, 1] : tensor<32x8x16x32xbf16> to tensor<32x1x16x32xbf16> + %extracted_slice_7 = tensor.extract_slice %1[%arg28, %arg24] [512, 32] [1, 1] : tensor<512x256xbf16> to tensor<512x32xbf16> + %unpack = tensor.unpack %extracted_slice_6 inner_dims_pos = [0, 1] inner_tiles = [16, 32] into %extracted_slice_7 : tensor<32x1x16x32xbf16> -> tensor<512x32xbf16> + %extracted_slice_8 = tensor.extract_slice %arg29[%arg20, %arg24] [32, 32] [1, 1] : tensor<64x128xbf16> to tensor<32x32xbf16> + %13 = linalg.fill ins(%cst : bf16) outs(%extracted_slice_8 : tensor<32x32xbf16>) -> tensor<32x32xbf16> + %expanded = tensor.expand_shape %extracted_slice_5 [[0, 1], [2]] output_shape [1, 32, 512] : tensor<32x512xbf16> into tensor<1x32x512xbf16> + %expanded_9 = tensor.expand_shape %unpack [[0, 1], [2]] output_shape [1, 32, 512] : tensor<512x32xbf16> into tensor<1x512x32xbf16> + %14 = linalg.batch_reduce_matmul ins(%expanded, %expanded_9 : tensor<1x32x512xbf16>, tensor<1x512x32xbf16>) outs(%13 : tensor<32x32xbf16>) -> tensor<32x32xbf16> + %15 = affine.apply #map4(%arg12, %arg24, %arg4) + %16 = affine.apply #map5(%arg8, %arg20, %arg3) + %extracted_slice_10 = tensor.extract_slice %arg2[%15] [32] [1] : tensor<256xbf16> to tensor<32xbf16> + %extracted_slice_11 = tensor.extract_slice %0[%16, %15] [32, 32] [1, 1] : tensor<128x256xbf16> to tensor<32x32xbf16> + %broadcasted = linalg.broadcast ins(%extracted_slice_10 : tensor<32xbf16>) outs(%extracted_slice_11 : tensor<32x32xbf16>) dimensions = [0] + %extracted_slice_12 = tensor.extract_slice %arg30[%arg20, %arg24] [32, 32] [1, 1] : tensor<64x128xbf16> to tensor<32x32xbf16> + %17 = linalg.add ins(%14, %broadcasted : tensor<32x32xbf16>, tensor<32x32xbf16>) outs(%extracted_slice_12 : tensor<32x32xbf16>) -> tensor<32x32xbf16> + 
%inserted_slice_13 = tensor.insert_slice %14 into %arg29[%arg20, %arg24] [32, 32] [1, 1] : tensor<32x32xbf16> into tensor<64x128xbf16> + %extracted_slice_14 = tensor.extract_slice %arg31[%arg20, %arg24] [32, 32] [1, 1] : tensor<64x128xbf16> to tensor<32x32xbf16> + %18 = linalg.exp ins(%17 : tensor<32x32xbf16>) outs(%extracted_slice_14 : tensor<32x32xbf16>) -> tensor<32x32xbf16> + %inserted_slice_15 = tensor.insert_slice %17 into %arg30[%arg20, %arg24] [32, 32] [1, 1] : tensor<32x32xbf16> into tensor<64x128xbf16> + %inserted_slice_16 = tensor.insert_slice %18 into %arg31[%arg20, %arg24] [32, 32] [1, 1] : tensor<32x32xbf16> into tensor<64x128xbf16> + scf.yield %inserted_slice_13, %inserted_slice_15, %inserted_slice_16 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> + } + scf.yield %10#0, %10#1, %10#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> + } + scf.yield %9#0, %9#1, %9#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> + } + %inserted_slice = tensor.insert_slice %8#0 into %arg17[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<64x128xbf16> + %inserted_slice_3 = tensor.insert_slice %8#1 into %arg18[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<64x128xbf16> + %inserted_slice_4 = tensor.insert_slice %8#2 into %arg19[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<64x128xbf16> + scf.yield %inserted_slice, %inserted_slice_3, %inserted_slice_4 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> + } + scf.yield %7#0, %7#1, %7#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> + } + scf.yield %6#0, %6#1, %6#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> + } + scf.forall.in_parallel { + tensor.parallel_insert_slice %5#2 into %arg7[%3, %4] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<128x256xbf16> + tensor.parallel_insert_slice %5#1 into %arg6[%3, %4] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<128x256xbf16> + tensor.parallel_insert_slice %5#0 into %arg5[%3, %4] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<128x256xbf16> + } + } + return %2#2 : tensor<128x256xbf16> + } -// func.func @matmul_add(%arg0: tensor<8192x12288xf16>, %arg1: tensor<12288x16384xf16>, %arg2: tensor<8192x16384xf32>, %arg3: tensor<8192x16384xf32>) -> tensor<8192x16384xf32> { -// %0 = linalg.matmul {__fused_op__ = [0], name = "dot", tile_sizes = [[128, 128], [32, 32]]} ins(%arg0, %arg1 : tensor<8192x12288xf16>, tensor<12288x16384xf16>) outs(%arg2 : tensor<8192x16384xf32>) -> tensor<8192x16384xf32> -// %1 = tensor.empty() : tensor<8192x16384xf32> -// %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%0, %arg3 : tensor<8192x16384xf32>, tensor<8192x16384xf32>) outs(%1 : tensor<8192x16384xf32>) attrs = {__root_op__ = 0 : i64} { -// ^bb0(%in: f32, %in_0: f32, %out: f32): -// %4 = arith.addf %in, %in_0 : f32 -// linalg.yield %4 : f32 -// } -> tensor<8192x16384xf32> -// %c0 = arith.constant 0 : index -// %c8192 = arith.constant 8192 : index -// %c128 = arith.constant 128 : index -// %3 = scf.for %arg4 = %c0 to %c8192 step %c128 iter_args(%arg5 = %1) -> (tensor<8192x16384xf32>) { -// %c0_0 = arith.constant 0 : index -// %c16384 = arith.constant 16384 : index -// %c128_1 = arith.constant 128 : index -// %4 = scf.for %arg6 = %c0_0 to %c16384 step %c128_1 iter_args(%arg7 = %arg5) -> (tensor<8192x16384xf32>) { -// %extracted_slice = 
tensor.extract_slice %arg0[%arg4, 0] [128, 12288] [1, 1] : tensor<8192x12288xf16> to tensor<128x12288xf16> -// %extracted_slice_2 = tensor.extract_slice %arg1[0, %arg6] [12288, 128] [1, 1] : tensor<12288x16384xf16> to tensor<12288x128xf16> -// %extracted_slice_3 = tensor.extract_slice %arg2[%arg4, %arg6] [128, 128] [1, 1] : tensor<8192x16384xf32> to tensor<128x128xf32> -// %5 = linalg.matmul {__fused_op__ = [0], name = "dot", tile_sizes = [[128, 128], [32, 32]]} ins(%extracted_slice, %extracted_slice_2 : tensor<128x12288xf16>, tensor<12288x128xf16>) outs(%extracted_slice_3 : tensor<128x128xf32>) -> tensor<128x128xf32> -// %extracted_slice_4 = tensor.extract_slice %0[%arg4, %arg6] [128, 128] [1, 1] : tensor<8192x16384xf32> to tensor<128x128xf32> -// %extracted_slice_5 = tensor.extract_slice %arg3[%arg4, %arg6] [128, 128] [1, 1] : tensor<8192x16384xf32> to tensor<128x128xf32> -// %extracted_slice_6 = tensor.extract_slice %arg7[%arg4, %arg6] [128, 128] [1, 1] : tensor<8192x16384xf32> to tensor<128x128xf32> -// %6 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%5, %extracted_slice_5 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%extracted_slice_6 : tensor<128x128xf32>) attrs = {__root_op__ = 0 : i64} { -// ^bb0(%in: f32, %in_9: f32, %out: f32): -// %8 = arith.addf %in, %in_9 : f32 -// linalg.yield %8 : f32 -// } -> tensor<128x128xf32> -// %c0_7 = arith.constant 0 : index -// %c128_8 = arith.constant 128 : index -// %c32 = arith.constant 32 : index -// %7 = scf.for %arg8 = %c0_7 to %c128_8 step %c32 iter_args(%arg9 = %extracted_slice_6) -> (tensor<128x128xf32>) { -// %c0_9 = arith.constant 0 : index -// %c128_10 = arith.constant 128 : index -// %c32_11 = arith.constant 32 : index -// %8 = scf.for %arg10 = %c0_9 to %c128_10 step %c32_11 iter_args(%arg11 = %arg9) -> (tensor<128x128xf32>) { -// %extracted_slice_12 = tensor.extract_slice %extracted_slice[%arg8, 0] [32, 12288] [1, 1] : tensor<128x12288xf16> to tensor<32x12288xf16> -// %extracted_slice_13 = tensor.extract_slice %extracted_slice_2[0, %arg10] [12288, 32] [1, 1] : tensor<12288x128xf16> to tensor<12288x32xf16> -// %extracted_slice_14 = tensor.extract_slice %extracted_slice_3[%arg8, %arg10] [32, 32] [1, 1] : tensor<128x128xf32> to tensor<32x32xf32> -// %9 = linalg.matmul {__fused_op__ = [0], name = "dot", tile_sizes = [[128, 128], [32, 32]]} ins(%extracted_slice_12, %extracted_slice_13 : tensor<32x12288xf16>, tensor<12288x32xf16>) outs(%extracted_slice_14 : tensor<32x32xf32>) -> tensor<32x32xf32> -// %extracted_slice_15 = tensor.extract_slice %5[%arg8, %arg10] [32, 32] [1, 1] : tensor<128x128xf32> to tensor<32x32xf32> -// %extracted_slice_16 = tensor.extract_slice %extracted_slice_5[%arg8, %arg10] [32, 32] [1, 1] : tensor<128x128xf32> to tensor<32x32xf32> -// %extracted_slice_17 = tensor.extract_slice %arg11[%arg8, %arg10] [32, 32] [1, 1] : tensor<128x128xf32> to tensor<32x32xf32> -// %10 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%9, %extracted_slice_16 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice_17 : tensor<32x32xf32>) attrs = {__root_op__ = 0 : i64} { -// ^bb0(%in: f32, %in_19: f32, %out: f32): -// %11 = arith.addf %in, %in_19 : f32 -// linalg.yield %11 : f32 -// } -> tensor<32x32xf32> -// %inserted_slice_18 = tensor.insert_slice %10 into 
%arg11[%arg8, %arg10] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<128x128xf32> -// scf.yield %inserted_slice_18 : tensor<128x128xf32> -// } {__parallel_loop__ = 1 : i64} -// scf.yield %8 : tensor<128x128xf32> -// } {__parallel_loop__ = 1 : i64} -// %inserted_slice = tensor.insert_slice %7 into %arg7[%arg4, %arg6] [128, 128] [1, 1] : tensor<128x128xf32> into tensor<8192x16384xf32> -// scf.yield %inserted_slice : tensor<8192x16384xf32> -// } {__parallel_loop__ = 0 : i64} -// scf.yield %4 : tensor<8192x16384xf32> -// } {__parallel_loop__ = 0 : i64} -// return %3 : tensor<8192x16384xf32> -// } +func.func @matmul_add(%arg0: tensor<8192x12288xf16>, %arg1: tensor<12288x16384xf16>, %arg2: tensor<8192x16384xf32>, %arg3: tensor<8192x16384xf32>) -> tensor<8192x16384xf32> { + %0 = linalg.matmul {__fused_op__ = [0], name = "dot", tile_sizes = [[128, 128], [32, 32]]} ins(%arg0, %arg1 : tensor<8192x12288xf16>, tensor<12288x16384xf16>) outs(%arg2 : tensor<8192x16384xf32>) -> tensor<8192x16384xf32> + %1 = tensor.empty() : tensor<8192x16384xf32> + %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%0, %arg3 : tensor<8192x16384xf32>, tensor<8192x16384xf32>) outs(%1 : tensor<8192x16384xf32>) attrs = {__root_op__ = 0 : i64} { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %4 = arith.addf %in, %in_0 : f32 + linalg.yield %4 : f32 + } -> tensor<8192x16384xf32> + %c0 = arith.constant 0 : index + %c8192 = arith.constant 8192 : index + %c128 = arith.constant 128 : index + %3 = scf.for %arg4 = %c0 to %c8192 step %c128 iter_args(%arg5 = %1) -> (tensor<8192x16384xf32>) { + %c0_0 = arith.constant 0 : index + %c16384 = arith.constant 16384 : index + %c128_1 = arith.constant 128 : index + %4 = scf.for %arg6 = %c0_0 to %c16384 step %c128_1 iter_args(%arg7 = %arg5) -> (tensor<8192x16384xf32>) { + %extracted_slice = tensor.extract_slice %arg0[%arg4, 0] [128, 12288] [1, 1] : tensor<8192x12288xf16> to tensor<128x12288xf16> + %extracted_slice_2 = tensor.extract_slice %arg1[0, %arg6] [12288, 128] [1, 1] : tensor<12288x16384xf16> to tensor<12288x128xf16> + %extracted_slice_3 = tensor.extract_slice %arg2[%arg4, %arg6] [128, 128] [1, 1] : tensor<8192x16384xf32> to tensor<128x128xf32> + %5 = linalg.matmul {__fused_op__ = [0], name = "dot", tile_sizes = [[128, 128], [32, 32]]} ins(%extracted_slice, %extracted_slice_2 : tensor<128x12288xf16>, tensor<12288x128xf16>) outs(%extracted_slice_3 : tensor<128x128xf32>) -> tensor<128x128xf32> + %extracted_slice_4 = tensor.extract_slice %0[%arg4, %arg6] [128, 128] [1, 1] : tensor<8192x16384xf32> to tensor<128x128xf32> + %extracted_slice_5 = tensor.extract_slice %arg3[%arg4, %arg6] [128, 128] [1, 1] : tensor<8192x16384xf32> to tensor<128x128xf32> + %extracted_slice_6 = tensor.extract_slice %arg7[%arg4, %arg6] [128, 128] [1, 1] : tensor<8192x16384xf32> to tensor<128x128xf32> + %6 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%5, %extracted_slice_5 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%extracted_slice_6 : tensor<128x128xf32>) attrs = {__root_op__ = 0 : i64} { + ^bb0(%in: f32, %in_9: f32, %out: f32): + %8 = arith.addf %in, %in_9 : f32 + linalg.yield %8 : f32 + } -> tensor<128x128xf32> + %c0_7 = arith.constant 0 : index + %c128_8 = arith.constant 128 : index + %c32 = arith.constant 32 : index + %7 = scf.for %arg8 = 
%c0_7 to %c128_8 step %c32 iter_args(%arg9 = %extracted_slice_6) -> (tensor<128x128xf32>) { + %c0_9 = arith.constant 0 : index + %c128_10 = arith.constant 128 : index + %c32_11 = arith.constant 32 : index + %8 = scf.for %arg10 = %c0_9 to %c128_10 step %c32_11 iter_args(%arg11 = %arg9) -> (tensor<128x128xf32>) { + %extracted_slice_12 = tensor.extract_slice %extracted_slice[%arg8, 0] [32, 12288] [1, 1] : tensor<128x12288xf16> to tensor<32x12288xf16> + %extracted_slice_13 = tensor.extract_slice %extracted_slice_2[0, %arg10] [12288, 32] [1, 1] : tensor<12288x128xf16> to tensor<12288x32xf16> + %extracted_slice_14 = tensor.extract_slice %extracted_slice_3[%arg8, %arg10] [32, 32] [1, 1] : tensor<128x128xf32> to tensor<32x32xf32> + %9 = linalg.matmul {__fused_op__ = [0], name = "dot", tile_sizes = [[128, 128], [32, 32]]} ins(%extracted_slice_12, %extracted_slice_13 : tensor<32x12288xf16>, tensor<12288x32xf16>) outs(%extracted_slice_14 : tensor<32x32xf32>) -> tensor<32x32xf32> + %extracted_slice_15 = tensor.extract_slice %5[%arg8, %arg10] [32, 32] [1, 1] : tensor<128x128xf32> to tensor<32x32xf32> + %extracted_slice_16 = tensor.extract_slice %extracted_slice_5[%arg8, %arg10] [32, 32] [1, 1] : tensor<128x128xf32> to tensor<32x32xf32> + %extracted_slice_17 = tensor.extract_slice %arg11[%arg8, %arg10] [32, 32] [1, 1] : tensor<128x128xf32> to tensor<32x32xf32> + %10 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%9, %extracted_slice_16 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice_17 : tensor<32x32xf32>) attrs = {__root_op__ = 0 : i64} { + ^bb0(%in: f32, %in_19: f32, %out: f32): + %11 = arith.addf %in, %in_19 : f32 + linalg.yield %11 : f32 + } -> tensor<32x32xf32> + %inserted_slice_18 = tensor.insert_slice %10 into %arg11[%arg8, %arg10] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<128x128xf32> + scf.yield %inserted_slice_18 : tensor<128x128xf32> + } {__parallel_loop__ = 1 : i64} + scf.yield %8 : tensor<128x128xf32> + } {__parallel_loop__ = 1 : i64} + %inserted_slice = tensor.insert_slice %7 into %arg7[%arg4, %arg6] [128, 128] [1, 1] : tensor<128x128xf32> into tensor<8192x16384xf32> + scf.yield %inserted_slice : tensor<8192x16384xf32> + } {__parallel_loop__ = 0 : i64} + scf.yield %4 : tensor<8192x16384xf32> + } {__parallel_loop__ = 0 : i64} + return %3 : tensor<8192x16384xf32> +} From 7200565175b9d8810b05c71003ff38ed82bee475 Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Mon, 8 Jul 2024 16:20:04 +0800 Subject: [PATCH 14/66] add check test --- include/gc/Transforms/TilingVector.h | 135 ++-- lib/gc/Transforms/CPUPhysicalResigterPass.cpp | 638 ++++++++++-------- .../gc/transforms/cpu-vetor-distribution.mlir | 144 ++-- 3 files changed, 472 insertions(+), 445 deletions(-) diff --git a/include/gc/Transforms/TilingVector.h b/include/gc/Transforms/TilingVector.h index 6f54d8a4b..d8abce0fe 100644 --- a/include/gc/Transforms/TilingVector.h +++ b/include/gc/Transforms/TilingVector.h @@ -55,7 +55,7 @@ namespace gc { namespace { Value makeIndexArithConstantOp(OpBuilder &opBuilder, Location &loc, int64_t x); -void checkAndSetOperand( +void setOperationCorrectOperand( Operation *op, const ValueRange &iterArgs, const llvm::DenseMap &operandIdxMap, const llvm::SmallVector &inductionVars, @@ -84,8 +84,6 @@ class VectorFusionStrategy { // can fused into prev operation which axis position llvm::DenseMap opAnchorPos; - llvm::SmallVector, 8> ignoreInitOperations; - 
func::FuncOp func; public: @@ -98,9 +96,6 @@ class VectorFusionStrategy { llvm::SmallVector &getGroupMaxSteps() { return groupMaxSteps; } func::FuncOp getFunc() { return func; } - llvm::SmallVector, 8> getIgnoreInitOperations() { - return ignoreInitOperations; - } VectorFusionStrategy() = default; VectorFusionStrategy(func::FuncOp func) : func(func) {} @@ -126,8 +121,7 @@ template class SpecialOperationCanonicalizer { }; class MultiReductionCanonicalizer - : virtual public SpecialOperationCanonicalizer< - vector::MultiDimReductionOp> { + : public SpecialOperationCanonicalizer { private: llvm::SmallVector reductionAxis, parallelAxis; std::queue prevOps, postOps, accRelatedOps, sourceRelatedOps; @@ -169,7 +163,7 @@ class MultiReductionCanonicalizer }; class BroadcastCanonicalizer - : virtual public SpecialOperationCanonicalizer { + : public SpecialOperationCanonicalizer { private: public: BroadcastCanonicalizer( @@ -179,7 +173,7 @@ class BroadcastCanonicalizer }; class TransposeCanonicalizer - : virtual public SpecialOperationCanonicalizer { + : public SpecialOperationCanonicalizer { private: public: TransposeCanonicalizer( @@ -189,7 +183,7 @@ class TransposeCanonicalizer }; class ShapeCastCanonicalizer - : virtual public SpecialOperationCanonicalizer { + : public SpecialOperationCanonicalizer { private: public: ShapeCastCanonicalizer( @@ -224,6 +218,7 @@ class CanonicalizerCommonUsedData { llvm::DenseMap &opPermuationMap) : fusionStrategy(fusionStrategy), groupOpResults(groupOpResults), groupOpIterArgs(groupOpIterArgs), opPermuationMap(opPermuationMap) {} + virtual ~CanonicalizerCommonUsedData(){}; // set methods void setFuseStrategy(VectorFusionStrategy &strategy) { @@ -266,86 +261,116 @@ class CanonicalizerCommonUsedData { return opPermuationMap; } - llvm::SmallVector &getMultiRdCanonicalizer() { + llvm::SmallVector & + getMultiRdCanonicalizers() { return multiRdCanonicalizers; } - llvm::SmallVector &getBroadcastCanonicalizer() { + llvm::SmallVector &getBroadcastCanonicalizers() { return broadcastCanonicalizers; } - llvm::SmallVector &getTransposeCanonicalizer() { + llvm::SmallVector &getTransposeCanonicalizers() { return transposeCanonicalizers; } - llvm::SmallVector &getShapeCastCanonicalizer() { + llvm::SmallVector &getShapeCastCanonicalizers() { return shapeCastCanonicalizers; } // other methods - void initSpeicalOperationCanonicalizers(); + bool isGroupHasSpecialOperation(const size_t grpIdx); }; -class CanonicalizerVectorOperation { +class ForLoopGenerator : virtual public CanonicalizerCommonUsedData { + func::FuncOp func; + +public: + virtual ~ForLoopGenerator() {} + void setGeneratorFunc(func::FuncOp &func) { this->func = func; } + void generateGroupOpVectorizedIR(const int idx); + void rewriteOperationAsVectorize(OpBuilder &rewriter, size_t groupId, + const std::queue &queue = {}); + void createNewConstantOp(Operation *srcOp, + vector::TransferWriteOp *transferWriteOp); + // elementwise for loop + mlir::FailureOr + generateVectorizedForLoop(const size_t groupId, IRRewriter &rewriter, + const VectorType &vectorType); + scf::ForOp + constructNestedForOp(const size_t forDimIdx, const size_t groupIdx, + OpBuilder &b, const Location &loc, + const ValueRange &iterArgs, const VectorType &type, + const llvm::ArrayRef &dims, + llvm::SmallVector &inductionVars, + const llvm::DenseMap &operandIdxMap); + void moveOperationsToCurrentForBody( + const size_t groupIdx, OpBuilder &b, + const llvm::SmallVector &inductionVars, + const llvm::DenseMap &operandIdxMap, + const ValueRange &loopState, 
const std::queue &queue = {}); + + // multireduction forloop methods + scf::ForOp generateMultiReductionForLoop(const size_t grpIdx); + scf::ForOp + parallelAxisGenerateForLoop(OpBuilder &opBuilder, const int groupIdx, + const size_t parallelIdx, ValueRange &initArgs, + llvm::SmallVector &inductionVars, + Value &originalWriteResult); + + scf::ForOp + reductionAxisGenerateForLoop(OpBuilder &opBuilder, const int groupIdx, + const size_t reductionIdx, ValueRange &initArgs, + llvm::SmallVector &inductionVars); +}; + +class VectorOperationAnalysizer : virtual public CanonicalizerCommonUsedData { +private: + func::FuncOp func; + +public: + virtual ~VectorOperationAnalysizer(){}; + void generateEmptyTensorAndWrite( + Operation *sourceOp, llvm::DenseMap> + &srcOpCanoniclizedMap); + void setAnalysisFunc(func::FuncOp &func) { this->func = func; } + void analysisEmptyGroupAndMaxSteps(); + void analysisGroupOperaion(); + void analysisGroupOperationResults(); +}; + +class CanonicalizerVectorOperation : virtual public VectorOperationAnalysizer, + ForLoopGenerator { private: func::FuncOp func; IRRewriter rewriter; CanonicalizerKind kind; - CanonicalizerCommonUsedData commonUsedData; public: CanonicalizerVectorOperation( func::FuncOp func, CanonicalizerKind kind = CanonicalizerKind::OperationsGroup) : func(func), rewriter(func), kind(kind) { + setAnalysisFunc(func); + setGeneratorFunc(func); // vector operation fusion if (kind == CanonicalizerKind::OperationsGroup) { auto fusionStrategy = VectorFusionStrategy(func); fusionStrategy.run(); - commonUsedData.setFuseStrategy(fusionStrategy); + setFuseStrategy(fusionStrategy); } } + virtual ~CanonicalizerVectorOperation(){}; // get functions func::FuncOp &getFunc() { return func; }; IRRewriter &getIRWewriter() { return rewriter; } - CanonicalizerCommonUsedData &getCommonUsedData() { return commonUsedData; } - - void generateGroupOpVectorizedIR(const int idx); - - void analysisEmptyGroupAndMaxSteps(); - void analysisGroupOperaionOperandsResults(); - - void generateEmptyTensorAndWrite( - Operation *sourceOp, llvm::DenseMap> - &srcOpCanoniclizedMap); - void analysisGroupOperationResults(); - - LogicalResult canonicalizeReductionOperation(); - LogicalResult canonicalizeTransposeOperation(vector::TransposeOp &transposeOp, - IRRewriter &rewriter); - void rewriteOperationAsVectorize(OpBuilder &rewriter, size_t groupId, - const std::queue &queue = {}); - void createNewConstantOp(Operation *srcOp, - vector::TransferWriteOp *transferWriteOp); - - // special operation methods - scf::ForOp generateMultiReductionForLoop(const size_t grpIdx); - void getCandidateSpecialOps(); + // void canonicalizeSpecialOperation(); - - scf::ForOp - parallelAxisGenerateForLoop(OpBuilder &opBuilder, const int groupIdx, - const size_t parallelIdx, ValueRange &initArgs, - llvm::SmallVector &inductionVars, - Value &originalWriteResult); - - scf::ForOp - reductionAxisGenerateForLoop(OpBuilder &opBuilder, const int groupIdx, - const size_t reductionIdx, ValueRange &initArgs, - llvm::SmallVector &inductionVars); - - bool isGroupHasSpecialOperation(const size_t grpIdx); + LogicalResult canonicalizeReductionOperation(); + void clearSpecialOperationCanonicalizers(); + void dummyInitSpecialOperation(); + void initSpeicalOperationCanonicalizers(); void run(); }; diff --git a/lib/gc/Transforms/CPUPhysicalResigterPass.cpp b/lib/gc/Transforms/CPUPhysicalResigterPass.cpp index 2795951c9..2d77c2b63 100644 --- a/lib/gc/Transforms/CPUPhysicalResigterPass.cpp +++ b/lib/gc/Transforms/CPUPhysicalResigterPass.cpp 
@@ -25,19 +25,21 @@ struct HardwareInfo { bool favx2 = true; } HW; -void printGroupOps(llvm::SmallVector, 8> &opGroups) { - for (auto grp : opGroups) { - if (grp.empty()) { - continue; - } - std::queue tmpQ(grp); - while (!tmpQ.empty()) { - auto cur = tmpQ.front(); - tmpQ.pop(); - cur->dump(); - } - } -} +// void printGroupOps(llvm::SmallVector, 8> &opGroups) { +// for (auto grp : opGroups) { +// if (grp.empty()) { +// continue; +// } +// std::cout << "__________________ group start_____________" << std::endl; +// std::queue tmpQ(grp); +// while (!tmpQ.empty()) { +// auto cur = tmpQ.front(); +// tmpQ.pop(); +// cur->dump(); +// } +// std::cout << "__________________ group end_____________" << std::endl; +// } +// } bool isSpecialOp(Operation *op) { return llvm::isa(op) || @@ -60,8 +62,46 @@ bool is_innermost_operation(Operation *op) { return inner_most; } +bool isNotSupportOperation(Operation *op) { + return llvm::isa(op) || + llvm::isa(op) || + llvm::isa(op) || + llvm::isa(op) || + llvm::isa(op); +} + +bool isReadOrWriteOperation(Operation *op) { + return llvm::isa(op) || + llvm::isa(op); +} + +// TODO: Need to support these operations in the future +bool hasNotSupportOperation(func::FuncOp *func) { + auto walkRes = func->walk([](Operation *op) { + if (isNotSupportOperation(op)) { + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return walkRes != WalkResult::advance(); +} + +// select nearest even step +int get_nearest_vector_step(const int step) { + assert(step > 0); + int nbits = 0, n = step; + while (n) { + n = n >> 1; + nbits++; + } + assert(nbits <= 6 || (nbits == 7 && step == 64)); + return (1 << (nbits - 1)) == step ? step : (1 << nbits); +} + int generateValidSteps(int steps, VectorType type) { - return type.getShape().back() >= steps ? (steps > 16 ? 16 : steps) : steps; + return type.getShape().back() >= steps + ? (steps > 16 ? 
16 : steps) + : get_nearest_vector_step(type.getShape().back()); } // expr equals `vector rank` - 1 @@ -70,6 +110,22 @@ bool isLastDim(const AffineExpr &expr, const size_t rank) { mlir::dyn_cast(expr).getPosition() == rank - 1; } +[[nodiscard]] int getDataTypeValidSteps(const VectorType &type) { + auto typebits = type.getElementTypeBitWidth(); + const int favx512bits = 512; + const int favx2bits = 256; + if (HW.favx512f) { + return generateValidSteps(favx512bits / typebits, type); + } else if (HW.favx2) { + return generateValidSteps(favx2bits / typebits, type); + } else { + // invalid + LDBG("Please check the hardware information."); + assert(false && "Invalid hardware."); + return -1; + } +} + // Get the maximum number of current data types that a register can hold [[nodiscard]] int getDataTypeMAXSIMDLength(const VectorType &type) { auto typebits = type.getElementTypeBitWidth(); @@ -93,37 +149,47 @@ mlir::FailureOr getOperationVectorType(Operation *op) { if (!op) { return failure(); } - return TypeSwitch>(op) - .Case( - [&](vector::TransferWriteOp transferWriteOp) - -> mlir::FailureOr { - auto retType = mlir::dyn_cast( - transferWriteOp->getOperand(0).getType()); - if (retType) { - return retType; + auto isDynamicType = [](VectorType &type) { return !type.hasStaticShape(); }; + auto ret = + TypeSwitch>(op) + .Case( + [&](vector::TransferWriteOp transferWriteOp) + -> mlir::FailureOr { + auto retType = mlir::dyn_cast( + transferWriteOp->getOperand(0).getType()); + if (retType) { + return retType; + } + LDBG("TransferWrite Operation has wrong vector to write."); + return failure(); + }) + .Case( + [&](vector::TransferReadOp transferReadOp) + -> mlir::FailureOr { + return transferReadOp.getVectorType(); + }) + .Case( + [&](vector::MultiDimReductionOp multiReductionOp) { + return multiReductionOp.getSourceVectorType(); + }) + .Case( + [&](arith::ConstantOp constantOp) { return failure(); }) + .Default([&](Operation *op) -> mlir::FailureOr { + if (!op->getResults().empty()) { + auto t = mlir::dyn_cast(op->getResultTypes().front()); + if (t) { + if (isDynamicType(t)) { + return failure(); + } + return t; + } } - LDBG("TransferWrite Operation has wrong vector to write."); return failure(); - }) - .Case([&](vector::TransferReadOp transferReadOp) - -> mlir::FailureOr { - return transferReadOp.getVectorType(); - }) - .Case( - [&](arith::ConstantOp constantOp) { return failure(); }) - .Case( - [&](vector::MultiDimReductionOp multiReductionOp) { - return multiReductionOp.getSourceVectorType(); - }) - .Default([&](Operation *op) -> mlir::FailureOr { - if (!op->getResults().empty()) { - auto t = mlir::dyn_cast(op->getResultTypes().front()); - if (t) { - return t; - } - } - return failure(); - }); + }); + if (!failed(ret) and isDynamicType(ret.value())) { + return failure(); + } + return ret; } VectorType getVectorzedType(Operation *op, uint32_t loop_step = 0) { @@ -137,7 +203,7 @@ VectorType getVectorzedType(Operation *op, uint32_t loop_step = 0) { } auto vectorizedType = baseType.value(); if (loop_step == 0) { - loop_step = getDataTypeMAXSIMDLength(vectorizedType); + loop_step = getDataTypeValidSteps(vectorizedType); } return VectorType::get({loop_step}, vectorizedType.getElementType()); } @@ -263,7 +329,7 @@ bool isReadWriteOnLastDim(Operation *op) { auto rank = mlir::dyn_cast(op) ? 
mlir::dyn_cast(op->getOperand(0).getType()).getRank() - : mlir::dyn_cast(op->getOperand(0).getType()).getRank(); + : mlir::dyn_cast(op->getOperand(1).getType()).getRank(); auto dimExpr = permutationMap.getResults(); bool find = false; for (auto &expr : dimExpr) { @@ -274,50 +340,51 @@ bool isReadWriteOnLastDim(Operation *op) { return find; } LDBG("The operation is not a read or write operation." << *op << "\n"); + assert(0 && "The operation is not a read or write operation."); return false; } -std::variant numeric_zero(Type type) { - Type t1 = getElementTypeOrSelf(type); - if (t1.isF32()) { - return 0.f; - } else if (t1.isBF16()) { - return bfloat2float(float2bfloat(0.f)); - } else if (t1.isF16()) { - return half2float(float2half(0.f)); - } else if (t1.isSignedInteger(8)) { - return int64_t(0); - } else if (t1.isSignedInteger(32)) { - return int64_t(0); - } else if (t1.isSignlessInteger(8) or t1.isSignlessInteger(32)) { - return int64_t(0); - } else { - LDBG("Unsupported data type: " << t1 << "\n"); - assert(0 && "unsupported data type"); - return (int64_t)0; - } -} - -std::variant numeric_one(Type type) { - Type t1 = getElementTypeOrSelf(type); - if (t1.isF32()) { - return 1.f; - } else if (t1.isBF16()) { - return bfloat2float(float2bfloat(1.f)); - } else if (t1.isF16()) { - return half2float(float2half(1.f)); - } else if (t1.isSignedInteger(8)) { - return int64_t(1); - } else if (t1.isSignedInteger(32)) { - return int64_t(1); - } else if (t1.isSignlessInteger(8) or t1.isSignlessInteger(32)) { - return int64_t(1); - } else { - LDBG("Unsupported data type: " << t1 << "\n"); - assert(0 && "unsupported data type"); - return (int64_t)1; - } -} +// std::variant numeric_zero(Type type) { +// Type t1 = getElementTypeOrSelf(type); +// if (t1.isF32()) { +// return 0.f; +// } else if (t1.isBF16()) { +// return bfloat2float(float2bfloat(0.f)); +// } else if (t1.isF16()) { +// return half2float(float2half(0.f)); +// } else if (t1.isSignedInteger(8)) { +// return int64_t(0); +// } else if (t1.isSignedInteger(32)) { +// return int64_t(0); +// } else if (t1.isSignlessInteger(8) or t1.isSignlessInteger(32)) { +// return int64_t(0); +// } else { +// LDBG("Unsupported data type: " << t1 << "\n"); +// assert(0 && "unsupported data type"); +// return (int64_t)0; +// } +// } + +// std::variant numeric_one(Type type) { +// Type t1 = getElementTypeOrSelf(type); +// if (t1.isF32()) { +// return 1.f; +// } else if (t1.isBF16()) { +// return bfloat2float(float2bfloat(1.f)); +// } else if (t1.isF16()) { +// return half2float(float2half(1.f)); +// } else if (t1.isSignedInteger(8)) { +// return int64_t(1); +// } else if (t1.isSignedInteger(32)) { +// return int64_t(1); +// } else if (t1.isSignlessInteger(8) or t1.isSignlessInteger(32)) { +// return int64_t(1); +// } else { +// LDBG("Unsupported data type: " << t1 << "\n"); +// assert(0 && "unsupported data type"); +// return (int64_t)1; +// } +// } std::variant numeric_limits_minimum(Type type) { Type t1 = getElementTypeOrSelf(type); @@ -326,7 +393,8 @@ std::variant numeric_limits_minimum(Type type) { } else if (t1.isBF16()) { return bfloat2float(float2bfloat(-std::numeric_limits::infinity())); } else if (t1.isF16()) { - return (float)-65504; + return (float)half2float( + float2half(-std::numeric_limits::infinity())); } else if (t1.isSignedInteger(8)) { return int64_t(-128); } else if (t1.isSignedInteger(32)) { @@ -347,7 +415,8 @@ std::variant numericLimitsMaximum(Type type) { } else if (t1.isBF16()) { return 
bfloat2float(float2bfloat(std::numeric_limits::infinity())); } else if (t1.isF16()) { - return (float)65504; + return (float)half2float( + float2half(std::numeric_limits::infinity())); } else if (t1.isSignedInteger(8)) { return int64_t(127); } else if (t1.isSignedInteger(32)) { @@ -733,38 +802,35 @@ Value makeIndexArithConstantOp(OpBuilder &opBuilder, const Location &loc, opBuilder.getIntegerAttr(opBuilder.getIndexType(), x)); } -void moveOperationsToCurrentForBody( - std::queue &opQueue, OpBuilder &b, ValueRange &loopState, - const llvm::DenseMap &operandIdxMap, +void ForLoopGenerator::moveOperationsToCurrentForBody( + const size_t groupIdx, OpBuilder &b, const llvm::SmallVector &inductionVars, - const llvm::DenseMap &opPermuationMap) { - // Operation *lastOperation = opQueue.front(); + const llvm::DenseMap &operandIdxMap, + const ValueRange &loopState, const std::queue &queue) { + std::queue opQueue = + queue.empty() ? getFusionStrategy().getOpGroups()[groupIdx] : queue; + auto &opPermuationMap = getOpPermuationMap(); while (!opQueue.empty()) { auto x = opQueue.front(); opQueue.pop(); - // if (lastOperation == x) { x->moveBefore(b.getBlock(), b.getBlock()->end()); - // } else { - // x->moveAfter(lastOperation); - // lastOperation = x; - // } // check operation type to set correct operand - checkAndSetOperand(x, loopState, operandIdxMap, inductionVars, - opPermuationMap); + setOperationCorrectOperand(x, loopState, operandIdxMap, inductionVars, + opPermuationMap); } } -scf::ForOp CanonicalizerVectorOperation::reductionAxisGenerateForLoop( +scf::ForOp ForLoopGenerator::reductionAxisGenerateForLoop( OpBuilder &opBuilder, const int groupIdx, const size_t reductionIdx, ValueRange &initArgs, llvm::SmallVector &inductionVars) { MultiReductionCanonicalizer rdCanonicalizer = - commonUsedData.getMultiRdCanonicalizer()[groupIdx]; + getMultiRdCanonicalizers()[groupIdx]; auto &multireductionOp = rdCanonicalizer.getCandidateOps()[0]; const auto loc = multireductionOp->getLoc(); auto &reductionAxis = rdCanonicalizer.getReductionAxis(); auto lastDimReduction = rdCanonicalizer.hasLastDimReduction(); auto vectorType = rdCanonicalizer.getSourceType(); - const int loopStep = getDataTypeMAXSIMDLength(vectorType); + const int loopStep = getDataTypeValidSteps(vectorType); auto isStandaloneOp = rdCanonicalizer.getIsStandAloneOp(); auto zero = makeIndexArithConstantOp(opBuilder, loc, 0); @@ -794,13 +860,11 @@ scf::ForOp CanonicalizerVectorOperation::reductionAxisGenerateForLoop( newReadOp->getResult(0), loopState.back()); maybeYieldValue(b, loc, reductionResult); } else { - auto &opPermuationMap = commonUsedData.getOpPermuationMap(); - auto &analysisResults = - commonUsedData.getGroupOpResults()[groupIdx]; + auto &analysisResults = getGroupOpResults()[groupIdx]; - auto &sourceOps = commonUsedData.getMultiRdCanonicalizer()[groupIdx] - .getSourceRelatedOps(); - auto &grpArgs = commonUsedData.getGroupOpIterArgs()[groupIdx]; + auto &sourceOps = + getMultiRdCanonicalizers()[groupIdx].getSourceRelatedOps(); + auto &grpArgs = getGroupOpIterArgs()[groupIdx]; rewriteOperationAsVectorize(b, groupIdx, sourceOps); llvm::DenseMap operandIdxMap; @@ -822,17 +886,15 @@ scf::ForOp CanonicalizerVectorOperation::reductionAxisGenerateForLoop( } if (analysisResults.contains(cur->getResults()[0])) { resultArray.emplace_back(cur->getResults()[0]); - commonUsedData.getMultiRdCanonicalizer()[groupIdx] + getMultiRdCanonicalizers()[groupIdx] .getOriginalOpResults() .insert(cur->getResults()[0]); - 
commonUsedData.getMultiRdCanonicalizer()[groupIdx] - .getResultIdxMap() - .insert({cur->getResults()[0], resultArray.size() - 1}); + getMultiRdCanonicalizers()[groupIdx].getResultIdxMap().insert( + {cur->getResults()[0], resultArray.size() - 1}); } } - moveOperationsToCurrentForBody(sourceOps, b, loopState, - operandIdxMap, inductionVars, - opPermuationMap); + moveOperationsToCurrentForBody(groupIdx, b, loopState, + operandIdxMap, inductionVars); auto reductionResult = makeArithReduction( b, loc, multireductionOp.getKind(), @@ -852,14 +914,15 @@ scf::ForOp CanonicalizerVectorOperation::reductionAxisGenerateForLoop( return forOp; } -scf::ForOp CanonicalizerVectorOperation::parallelAxisGenerateForLoop( +scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( OpBuilder &opBuilder, const int groupIdx, const size_t parallelIdx, ValueRange &initArgs, llvm::SmallVector &inductionVars, Value &originalWriteResult) { - auto &rdCanonicalizer = commonUsedData.getMultiRdCanonicalizer()[groupIdx]; + auto &rdCanonicalizer = getMultiRdCanonicalizers()[groupIdx]; auto &multiReductionOp = rdCanonicalizer.getCandidateOps()[0]; auto &vectorType = rdCanonicalizer.getSourceType(); auto &accType = rdCanonicalizer.getAccType(); + IRRewriter rewriterOfFunc(func); auto ¶llelAxis = rdCanonicalizer.getParallelAxis(); auto isStandaloneOp = rdCanonicalizer.getIsStandAloneOp(); @@ -881,7 +944,7 @@ scf::ForOp CanonicalizerVectorOperation::parallelAxisGenerateForLoop( loc, zero, numIter, forSteps, initArgs, [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) { inductionVars.emplace_back(iv); - auto &fusionStrategy = commonUsedData.getFusionStrategy(); + auto &fusionStrategy = getFusionStrategy(); auto &opIndexMap = fusionStrategy.getOpGroupIndexMap(); assert(opIndexMap.contains(multiReductionOp) && @@ -889,7 +952,6 @@ scf::ForOp CanonicalizerVectorOperation::parallelAxisGenerateForLoop( auto opIndex = opIndexMap[multiReductionOp]; auto &opGroups = fusionStrategy.getOpGroups(); - auto &opPermuationMap = commonUsedData.getOpPermuationMap(); auto opQueue = opGroups[opIndex]; auto multiReductionAcc = multiReductionOp.getAcc(); @@ -1009,16 +1071,12 @@ scf::ForOp CanonicalizerVectorOperation::parallelAxisGenerateForLoop( } else { auto prevOp = opQueue.front(); auto postOp = opQueue.back(); - auto &prevOps = - commonUsedData.getMultiRdCanonicalizer()[groupIdx].getPrevOps(); - auto &postOps = - commonUsedData.getMultiRdCanonicalizer()[groupIdx].getPostOps(); + auto &prevOps = getMultiRdCanonicalizers()[groupIdx].getPrevOps(); + auto &postOps = getMultiRdCanonicalizers()[groupIdx].getPostOps(); auto &accRelatedOps = - commonUsedData.getMultiRdCanonicalizer()[groupIdx] - .getAccRelatedOps(); + getMultiRdCanonicalizers()[groupIdx].getAccRelatedOps(); auto &sourceRelatedOps = - commonUsedData.getMultiRdCanonicalizer()[groupIdx] - .getSourceRelatedOps(); + getMultiRdCanonicalizers()[groupIdx].getSourceRelatedOps(); if (mlir::isa(prevOp)) { @@ -1036,15 +1094,15 @@ scf::ForOp CanonicalizerVectorOperation::parallelAxisGenerateForLoop( multiReductionOp.getSource().getDefiningOp(), prevOps); rewriteOperationAsVectorize(b, groupIdx, accRelatedOps); - auto &grpArgs = commonUsedData.getGroupOpIterArgs()[groupIdx]; + auto &grpArgs = getGroupOpIterArgs()[groupIdx]; llvm::DenseMap operandIdxMap; for (auto [idx, x] : llvm::enumerate(grpArgs)) { operandIdxMap[x] = idx; } - moveOperationsToCurrentForBody(accRelatedOps, b, loopState, - operandIdxMap, inductionVars, - opPermuationMap); - auto &grpResults = 
commonUsedData.getGroupOpResults()[groupIdx]; + moveOperationsToCurrentForBody(groupIdx, b, inductionVars, + operandIdxMap, loopState, + accRelatedOps); + auto &grpResults = getGroupOpResults()[groupIdx]; // next for loop llvm::SmallVector iterArgsArray; iterArgsArray.emplace_back(multiReductionAcc); @@ -1069,22 +1127,21 @@ scf::ForOp CanonicalizerVectorOperation::parallelAxisGenerateForLoop( inductionVars, originalWriteResult); rewriteOperationAsVectorize(b, groupIdx, postOps); - moveOperationsToCurrentForBody(postOps, b, loopState, + moveOperationsToCurrentForBody(groupIdx, b, loopState, operandIdxMap, inductionVars, - opPermuationMap); + postOps); auto replaceIfFn = [&](OpOperand &use) { return use.getOwner()->getBlock() == nxtFor->getBlock(); }; - rewriter.replaceOpUsesWithIf( + rewriterOfFunc.replaceOpUsesWithIf( multiReductionOp, nxtFor->getResults()[0], replaceIfFn); auto &originalResults = - commonUsedData.getMultiRdCanonicalizer()[groupIdx] - .getOriginalOpResults(); + getMultiRdCanonicalizers()[groupIdx].getOriginalOpResults(); for (auto [idx, x] : llvm::enumerate(originalResults)) { - rewriter.replaceOpUsesWithIf(x.getDefiningOp(), - nxtFor->getResults()[idx + 1], - replaceIfFn); + rewriterOfFunc.replaceOpUsesWithIf( + x.getDefiningOp(), nxtFor->getResults()[idx + 1], + replaceIfFn); } llvm::SmallVector resultsArray; llvm::SmallDenseMap parallelIdxMap; @@ -1156,15 +1213,16 @@ scf::ForOp CanonicalizerVectorOperation::parallelAxisGenerateForLoop( return forOp; } -scf::ForOp CanonicalizerVectorOperation::generateMultiReductionForLoop( - const size_t grpIdx) { - auto &grpArgs = commonUsedData.getGroupOpIterArgs()[grpIdx]; +scf::ForOp +ForLoopGenerator::generateMultiReductionForLoop(const size_t grpIdx) { + auto &grpArgs = getGroupOpIterArgs()[grpIdx]; llvm::SmallVector forLoopArgs(grpArgs.begin(), grpArgs.end()); llvm::SmallVector inductionVars; ValueRange initArgs(forLoopArgs); Value originalWriteResult; - auto &rdCanonicalizer = commonUsedData.getMultiRdCanonicalizer()[grpIdx]; + auto &rdCanonicalizer = getMultiRdCanonicalizers()[grpIdx]; auto &rdResultMap = rdCanonicalizer.getResultIdxMap(); + IRRewriter rewriter(func); OpBuilder opBuilder(rdCanonicalizer.getCandidateOps()[0]); @@ -1174,15 +1232,14 @@ scf::ForOp CanonicalizerVectorOperation::generateMultiReductionForLoop( auto replaceIfFn = [&](OpOperand &use) { return use.getOwner()->getBlock() == forOp->getBlock(); }; - for (auto &grpResult : commonUsedData.getGroupOpResults()[grpIdx]) { + for (auto &grpResult : getGroupOpResults()[grpIdx]) { rewriter.replaceOpUsesWithIf(grpResult.getDefiningOp(), forOp->getResults()[rdResultMap[grpResult]], replaceIfFn); } - rewriter.replaceOp( - commonUsedData.getMultiRdCanonicalizer()[grpIdx].getCandidateOps()[0], - forOp); + rewriter.replaceOp(getMultiRdCanonicalizers()[grpIdx].getCandidateOps()[0], + forOp); return forOp; } @@ -1191,39 +1248,6 @@ llvm::SmallVector &SpecialOperationCanonicalizer::getCandidateOps() { return candidateRdOps; }; -void CanonicalizerVectorOperation::getCandidateSpecialOps() { - auto grp = commonUsedData.getFusionStrategy().getOpGroups(); - // avoid seg fault - auto multiRdCanonicalizer = commonUsedData.getMultiRdCanonicalizer(); - multiRdCanonicalizer.clear(); - size_t start = 0; - while (start++ < grp.size()) { - multiRdCanonicalizer.emplace_back(MultiReductionCanonicalizer({})); - } - - auto idxGroup = commonUsedData.getFusionStrategy().getOpGroupIndexMap(); - func->walk([&](Operation *op) { - llvm::TypeSwitch(op) - .Case( - [&](vector::MultiDimReductionOp 
multiReductionOp) { - auto groupIdx = idxGroup[multiReductionOp]; - multiRdCanonicalizer[groupIdx].getCandidateOps().emplace_back( - multiReductionOp); - }) - .Case([&](vector::ShapeCastOp shapeCastOp) { - // shapeCastOps.insert(shapeCastOp); - // TODO - assert(0); - }) - .Case([&](vector::TransposeOp transposeOp) { - // transposeOps.insert(transposeOp); - // TODO - assert(0); - }) - .Default([&](Operation *) {}); - }); -}; - void MultiReductionCanonicalizer::initReductionAxis() { auto reductionAxisRange = getCandidateOps()[0].getReductionDims().getAsValueRange(); @@ -1280,19 +1304,26 @@ template void addDummyInit(llvm::SmallVector &canonicalizer) { canonicalizer.emplace_back(T({})); }; -void CanonicalizerCommonUsedData::initSpeicalOperationCanonicalizers() { +void CanonicalizerVectorOperation::clearSpecialOperationCanonicalizers() { + getMultiRdCanonicalizers().clear(); + getBroadcastCanonicalizers().clear(); + getTransposeCanonicalizers().clear(); + getShapeCastCanonicalizers().clear(); +} - broadcastCanonicalizers.clear(); - multiRdCanonicalizers.clear(); - transposeCanonicalizers.clear(); - shapeCastCanonicalizers.clear(); - auto &opGroups = fusionStrategy.getOpGroups(); - for (auto &grp : opGroups) { - addDummyInit(multiRdCanonicalizers); - addDummyInit(broadcastCanonicalizers); - addDummyInit(transposeCanonicalizers); - addDummyInit(shapeCastCanonicalizers); +void CanonicalizerVectorOperation::dummyInitSpecialOperation() { + addDummyInit(getMultiRdCanonicalizers()); + addDummyInit(getBroadcastCanonicalizers()); + addDummyInit(getTransposeCanonicalizers()); + addDummyInit(getShapeCastCanonicalizers()); +} + +void CanonicalizerVectorOperation::initSpeicalOperationCanonicalizers() { + clearSpecialOperationCanonicalizers(); + auto &opGroups = getFusionStrategy().getOpGroups(); + for (auto &grp : opGroups) { + dummyInitSpecialOperation(); if (grp.empty()) { continue; } @@ -1301,17 +1332,17 @@ void CanonicalizerCommonUsedData::initSpeicalOperationCanonicalizers() { auto op = tempQ.front(); tempQ.pop(); if (mlir::isa(op)) { - multiRdCanonicalizers.back().getCandidateOps().emplace_back( + getMultiRdCanonicalizers().back().getCandidateOps().emplace_back( mlir::dyn_cast(op)); - multiRdCanonicalizers.back().prepareSpecialOperationInfo(); + getMultiRdCanonicalizers().back().prepareSpecialOperationInfo(); } else if (mlir::isa(op)) { - broadcastCanonicalizers.back().getCandidateOps().emplace_back( + getBroadcastCanonicalizers().back().getCandidateOps().emplace_back( mlir::dyn_cast(op)); } else if (mlir::isa(op)) { - transposeCanonicalizers.back().getCandidateOps().emplace_back( + getTransposeCanonicalizers().back().getCandidateOps().emplace_back( mlir::dyn_cast(op)); } else if (mlir::isa(op)) { - shapeCastCanonicalizers.back().getCandidateOps().emplace_back( + getShapeCastCanonicalizers().back().getCandidateOps().emplace_back( mlir::dyn_cast(op)); } } @@ -1321,9 +1352,9 @@ void CanonicalizerCommonUsedData::initSpeicalOperationCanonicalizers() { LogicalResult CanonicalizerVectorOperation::canonicalizeReductionOperation() { OpBuilder::InsertionGuard guard(rewriter); - commonUsedData.initSpeicalOperationCanonicalizers(); + initSpeicalOperationCanonicalizers(); // traverse all groups - auto &multiRdCanonicalizers = commonUsedData.getMultiRdCanonicalizer(); + auto &multiRdCanonicalizers = getMultiRdCanonicalizers(); for (auto [groupId, rdCanonicalizer] : llvm::enumerate(multiRdCanonicalizers)) { auto &candidateOps = rdCanonicalizer.getCandidateOps(); @@ -1339,10 +1370,14 @@ LogicalResult 
CanonicalizerVectorOperation::canonicalizeReductionOperation() { void CanonicalizerVectorOperation::canonicalizeSpecialOperation() { // multireduction operation auto result = canonicalizeReductionOperation(); + if (failed(result)) { + LDBG("Failed to canonicalize reduction operation\n"); + assert(0 && "Failed to canonicalize reduction operation"); + } } void CanonicalizerVectorOperation::run() { - auto &fusionStrategy = commonUsedData.getFusionStrategy(); + auto &fusionStrategy = getFusionStrategy(); if (kind == CanonicalizerKind::OperationsGroup) { // 1. Analysis the operation's operands and results // We need to analyze which operation results are needed by other @@ -1351,7 +1386,7 @@ void CanonicalizerVectorOperation::run() { // operation operand as: map(operand, forloop yield result) -> operand = // loop yield result We put all the operation result into this map. - // 1.a. Find what results should be generated by current group for + // 1.a. Find results which should be generated by current group for // using as operands to other operations? // Traverse all operations. If the operand of operations in other groups or @@ -1362,11 +1397,11 @@ void CanonicalizerVectorOperation::run() { // 1.b. What operands are needed to find in the current group, and where // can they be obtained ? - // Thanks to 2.a, we get the result generated by the operations of + // Thanks to 1.a, we get the result generated by the operations of // each group, and this result will use `for loop yield` to generate a // new result. Since the scope of the parent block of mlir is covered // the current operation, the current operation does not need to pass these - // `for loop results` to the `iter args` of the required `for loop`. It + // `for loop results` to the `iterArgs` of the required `for loop`. It // only needs to replace the operand of the current operation with the // corresponding `for loop yield result`. @@ -1381,12 +1416,12 @@ void CanonicalizerVectorOperation::run() { // Query groupResultYeildSet to map operaion result value to scf.yield // result value. 
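The comments above describe the escape analysis in prose: a value produced in one group and consumed in another must become an scf.yield result of the producing group's loop, and the consumer is rewired to that loop result. A minimal standalone C++ sketch of that bookkeeping (illustrative only, not from the patch; ToyOp, producerGroup and yieldIndex are invented names and no MLIR types are involved):

#include <cstdio>
#include <map>
#include <vector>

struct ToyOp {
  int id;                    // result id produced by this op
  int group;                 // fusion group the op was assigned to
  std::vector<int> operands; // result ids consumed by this op
};

int main() {
  // Group 0 produces %0 and %1; group 1 consumes %1, so %1 must be yielded.
  std::vector<ToyOp> ops = {{0, 0, {}}, {1, 0, {0}}, {2, 1, {1}}};
  std::map<int, int> producerGroup;
  for (const ToyOp &op : ops)
    producerGroup[op.id] = op.group;

  // groupResults[g] = results of group g needed by another group, in the
  // order they will take in that group's scf.yield.
  std::map<int, std::vector<int>> groupResults;
  std::map<int, int> yieldIndex; // result id -> position in scf.yield
  for (const ToyOp &op : ops)
    for (int operand : op.operands)
      if (producerGroup[operand] != op.group && !yieldIndex.count(operand)) {
        std::vector<int> &results = groupResults[producerGroup[operand]];
        yieldIndex[operand] = (int)results.size();
        results.push_back(operand);
      }

  for (const auto &entry : yieldIndex)
    std::printf("%%%d escapes its group as scf.yield result #%d\n",
                entry.first, entry.second);
  return 0;
}

Running this reports that %1 escapes group 0 as yield result #0, which mirrors the mapping the comments call groupResultYeildSet.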
- analysisGroupOperaionOperandsResults(); + analysisGroupOperaion(); // Speical Operation Canonicalization canonicalizeSpecialOperation(); - // 2.Generate vectorized IR for each operation group - for (auto [idx, grp] : llvm::enumerate(fusionStrategy.getOpGroups())) { + // 2.Generate vectorized IR for each operation group + for (size_t idx = 0; idx < fusionStrategy.getOpGroups().size(); ++idx) { generateGroupOpVectorizedIR(idx); } @@ -1394,7 +1429,7 @@ void CanonicalizerVectorOperation::run() { DominanceInfo domInfo; eliminateCommonSubExpressions(rewriter, domInfo, func); } else { - // TODO: need to add directly canonicalize operations + // TODO: need to add directly canonicalize operations logic // generateGroupOpVectorizedIR(idx, grp, fusionStrategy.opGroupIndexMap); } } @@ -1425,7 +1460,7 @@ void CanonicalizerVectorOperation::run() { } // -void checkAndSetOperand( +void setOperationCorrectOperand( Operation *op, const ValueRange &iterArgs, const llvm::DenseMap &operandIdxMap, const llvm::SmallVector &inductionVars, @@ -1451,19 +1486,18 @@ void checkAndSetOperand( } } -scf::ForOp constructNestedForOp( - OpBuilder &b, const Location &loc, const ValueRange &iterArgs, - const VectorType &type, const llvm::ArrayRef &dims, size_t idx, - std::queue &queue, const llvm::SetVector &resultSet, +scf::ForOp ForLoopGenerator::constructNestedForOp( + const size_t forDimIdx, const size_t groupIdx, OpBuilder &b, + const Location &loc, const ValueRange &iterArgs, const VectorType &type, + const llvm::ArrayRef &dims, llvm::SmallVector &inductionVars, - const llvm::DenseMap &operandIdxMap, - const llvm::DenseMap &opPermuationMap) { - const int loop_step = getDataTypeMAXSIMDLength(type); + const llvm::DenseMap &operandIdxMap) { + const int loop_step = getDataTypeValidSteps(type); // loop initialization variable auto zero = makeIndexArithConstantOp(b, loc, 0); - auto forSteps = - makeIndexArithConstantOp(b, loc, idx == dims.size() - 1 ? loop_step : 1); - auto numIter = makeIndexArithConstantOp(b, loc, dims[idx]); + auto forSteps = makeIndexArithConstantOp( + b, loc, forDimIdx == dims.size() - 1 ? loop_step : 1); + auto numIter = makeIndexArithConstantOp(b, loc, dims[forDimIdx]); // Create a loop and move vectorized operation into loops. 
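constructNestedForOp above recurses over the vector dimensions, emitting one scf.for per dimension and stepping only the innermost dimension by the chosen lane count. A plain C++ stand-in for that loop structure (no MLIR API; the 4x8x32 shape and the 16-lane step are made-up values):

#include <cstdio>
#include <functional>
#include <vector>

int main() {
  std::vector<int> dims = {4, 8, 32}; // e.g. a vector<4x8x32xf32> workload
  const int loopStep = 16;            // lanes chosen for the innermost dim

  std::function<void(size_t, std::vector<int> &)> emit =
      [&](size_t dimIdx, std::vector<int> &ivs) {
        // Only the innermost loop advances by the vector length.
        int step = dimIdx == dims.size() - 1 ? loopStep : 1;
        for (int iv = 0; iv < dims[dimIdx]; iv += step) {
          ivs.push_back(iv);
          if (dimIdx == dims.size() - 1)
            std::printf("loop body at (%d, %d, %d) covering %d lanes\n",
                        ivs[0], ivs[1], ivs[2], loopStep);
          else
            emit(dimIdx + 1, ivs); // outer loop recurses into the next dim
          ivs.pop_back();
        }
      };

  std::vector<int> ivs;
  emit(0, ivs);
  return 0;
}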
auto forOp = b.create( @@ -1472,15 +1506,16 @@ scf::ForOp constructNestedForOp( inductionVars.emplace_back(iv); // inner most body of the loop - if (idx == dims.size() - 1) { - moveOperationsToCurrentForBody(queue, b, loopState, operandIdxMap, - inductionVars, opPermuationMap); + if (forDimIdx == dims.size() - 1) { + moveOperationsToCurrentForBody(groupIdx, b, inductionVars, + operandIdxMap, loopState); + auto &resultSet = getGroupOpResults()[groupIdx]; maybeYieldValue(b, loc, resultSet.getArrayRef()); } else { // outter loop - auto nxtFor = constructNestedForOp( - b, loc, loopState, type, dims, idx + 1, queue, resultSet, - inductionVars, operandIdxMap, opPermuationMap); + auto nxtFor = + constructNestedForOp(forDimIdx + 1, groupIdx, b, loc, loopState, + type, dims, inductionVars, operandIdxMap); maybeYieldValue(b, loc, nxtFor->getResults()); } }); @@ -1493,13 +1528,9 @@ bool isCompatibleVectorType(Operation *op1, Operation *op2) { if (failed(type1) || failed(type2)) { return false; } - auto isReadOrWrite = [](Operation *op) { - return mlir::isa(op) or - mlir::isa(op); - }; auto sp1 = type1.value(); auto sp2 = type2.value(); - if (isReadOrWrite(op1) or isReadOrWrite(op2)) { + if (isReadOrWriteOperation(op1) or isReadOrWriteOperation(op2)) { if (sp1.getRank() != sp2.getRank()) { return false; } @@ -1509,9 +1540,35 @@ bool isCompatibleVectorType(Operation *op1, Operation *op2) { } } } + if (sp1.getRank() != sp2.getRank()) { + return false; + } + bool isCompatible = true; + // from front to back + for (long i = 0; i < sp1.getRank(); i++) { + if (sp1.getDimSize(i) != sp2.getDimSize(i)) { + isCompatible = false; + break; + } + } + + return isCompatible; +} - auto min_rank = std::min(sp1.getRank(), sp2.getRank()) - 1; +bool isPartialCompatible(Operation *op1, Operation *op2) { + auto type1 = getOperationVectorType(op1); + auto type2 = getOperationVectorType(op2); + if (failed(type1) || failed(type2)) { + return false; + } + auto sp1 = type1.value(); + auto sp2 = type2.value(); + // must be total same + if (isReadOrWriteOperation(op1) or isReadOrWriteOperation(op2)) { + return false; + } bool isCompatible = true; + auto min_rank = std::min(sp1.getRank(), sp2.getRank()); // from front to back for (long i = 0; i < min_rank; i++) { if (sp1.getDimSize(i) != sp2.getDimSize(i)) { @@ -1658,7 +1715,8 @@ bool hasDataDependency(Operation *op1, Operation *op2) { if (!isSpecialOp(op1)) { return hasDataDependency(op2, op1); } - // TODO: remove this condition in the future + // TODO: Remove this condition to support special operation fusion in the + // future if (disableSpecialOp) { return true; } @@ -1776,6 +1834,9 @@ bool isNeedNewGroup(llvm::SmallVector, 8> &opGroups, // previous operation vector type is not compatible with current operation if (!isCompatibleVectorType(prevOp, op)) { + // TODO: Support partial compatible operation fusion + if (isPartialCompatible(prevOp, op)) { + } return true; } } @@ -1793,8 +1854,6 @@ void addOperationToGroup( opGroupIndexMap[op] = opGroups.size() - 1; } -bool isInitOperation(Operation *op) { return mlir::isa(op); } - // We classify the operations we are interested in after filtering. Operations // of in the same group have no data dependencies. Those operations can generate // a same outter for loop. 
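The grouping rule described above can be pictured with a toy walk over operations in program order: an operation joins the last group unless the previous operation is a special op or its vector shape is incompatible. The sketch below is deliberately simplified (compatible() stands in for isCompatibleVectorType and the data-dependency checks in hasDataDependency are ignored; shapes and op names are invented):

#include <cstdio>
#include <string>
#include <vector>

struct ToyOp {
  std::string name;
  std::vector<int> shape; // result vector shape
  bool special;           // e.g. multi_reduction / transpose / shape_cast
};

static bool compatible(const std::vector<int> &a, const std::vector<int> &b) {
  return a == b; // simplified stand-in for isCompatibleVectorType
}

int main() {
  std::vector<ToyOp> ops = {{"vector.transfer_read", {16, 32}, false},
                            {"arith.addf", {16, 32}, false},
                            {"vector.multi_reduction", {16}, true},
                            {"arith.mulf", {16}, false}};
  std::vector<std::vector<std::string>> groups(1);
  for (size_t i = 0; i < ops.size(); ++i) {
    bool needNewGroup =
        !groups.back().empty() &&
        (ops[i - 1].special || !compatible(ops[i - 1].shape, ops[i].shape));
    if (needNewGroup)
      groups.emplace_back();
    groups.back().push_back(ops[i].name);
  }
  for (size_t g = 0; g < groups.size(); ++g)
    for (const std::string &name : groups[g])
      std::printf("group %zu: %s\n", g, name.c_str());
  return 0;
}

With this input the read and the addf share group 0, the multi_reduction gets its own group, and the mulf starts a new group after the special op, which is the shape of the grouping the pass aims for.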
@@ -1806,14 +1865,6 @@ void VectorFusionStrategy::classifyOperations() { func->walk([&](Operation *op) { if (filterOperation(op)) { addOperationToGroup(opGroups, opGroupIndexMap, op); - // update init operation - } - while (ignoreInitOperations.size() < opGroups.size()) { - ignoreInitOperations.emplace_back(std::queue()); - } - // some init operations need to ignore - if (isInitOperation(op)) { - ignoreInitOperations.back().push(op); } }); } @@ -1885,9 +1936,9 @@ void setOperationOperandResult( } }; -void CanonicalizerVectorOperation::createNewConstantOp( +void ForLoopGenerator::createNewConstantOp( Operation *srcOp, vector::TransferWriteOp *transferWriteOp) { - auto &opPermuationMap = commonUsedData.getOpPermuationMap(); + auto &opPermuationMap = getOpPermuationMap(); IRRewriter srcWriter(srcOp); auto newOperandType = getVectorzedType(mlir::cast(srcOp)); auto srcConstantOp = mlir::dyn_cast(srcOp); @@ -1904,10 +1955,7 @@ void CanonicalizerVectorOperation::createNewConstantOp( srcOp->getLoc(), IntegerAttr::get(newOperandType, valueType.getSplatValue())); } - } else { - // write original vector into tensor - // then we transfer_read from the tensor assert(0 && "Not support non-splat constant value."); } } else { @@ -1927,16 +1975,14 @@ void CanonicalizerVectorOperation::createNewConstantOp( } /// Rewrite the operations in the group to vectorized form. -void CanonicalizerVectorOperation::rewriteOperationAsVectorize( +void ForLoopGenerator::rewriteOperationAsVectorize( OpBuilder &rewriter, size_t groupId, const std::queue &queue) { auto &groupOps = - queue.empty() ? commonUsedData.getFusionStrategy().getOpGroups()[groupId] - : queue; - auto &opMap = commonUsedData.getFusionStrategy().getOpGroupIndexMap(); - auto &opPermuationMap = commonUsedData.getOpPermuationMap(); + queue.empty() ? 
getFusionStrategy().getOpGroups()[groupId] : queue; + auto &opMap = getFusionStrategy().getOpGroupIndexMap(); + auto &opPermuationMap = getOpPermuationMap(); std::queue transformQueue(groupOps); - auto groupSteps = - commonUsedData.getFusionStrategy().getGroupMaxSteps()[groupId]; + auto groupSteps = getFusionStrategy().getGroupMaxSteps()[groupId]; while (!transformQueue.empty()) { auto op = transformQueue.front(); @@ -1952,10 +1998,8 @@ void CanonicalizerVectorOperation::rewriteOperationAsVectorize( if (mlir::isa(srcOp)) { createNewConstantOp(srcOp, &transferWriteOp); } else { - transferWriteOp->dump(); opPermuationMap.insert( {transferWriteOp, transferWriteOp.getPermutationMap()}); - newOperandType.dump(); transferWriteOp->getOperand(0).setType(newOperandType); setOpVectorizationPermutationMap( @@ -1988,9 +2032,9 @@ void CanonicalizerVectorOperation::rewriteOperationAsVectorize( "It should not appear this operation."); return failure(); }) - .Case([&](arith::ExtFOp extFop) { - auto newOperandType = getVectorzedType(extFop, groupSteps); - extFop->getResult(0).setType(newOperandType); + .Case([&](arith::ExtFOp extfOp) { + auto newOperandType = getVectorzedType(extfOp, groupSteps); + extfOp->getResult(0).setType(newOperandType); return success(); }) .Default([&](Operation *op) { @@ -2057,15 +2101,13 @@ void updateOpOperandResultInGroups( void VectorFusionStrategy::run() { classifyOperations(); } -void CanonicalizerVectorOperation::generateEmptyTensorAndWrite( +void VectorOperationAnalysizer::generateEmptyTensorAndWrite( Operation *sourceOp, llvm::DenseMap> &srcOpCanoniclizedMap) { - auto &commonUsedData = getCommonUsedData(); - auto &opGroups = commonUsedData.getFusionStrategy().getOpGroups(); - auto &opGroupIndexMap = - commonUsedData.getFusionStrategy().getOpGroupIndexMap(); - auto &groupOpIterArgs = commonUsedData.getGroupOpIterArgs(); - auto &groupOpResults = commonUsedData.getGroupOpResults(); + auto &opGroups = getFusionStrategy().getOpGroups(); + auto &opGroupIndexMap = getFusionStrategy().getOpGroupIndexMap(); + auto &groupOpIterArgs = getGroupOpIterArgs(); + auto &groupOpResults = getGroupOpResults(); auto sourceOpGid = opGroupIndexMap[sourceOp]; auto [resultTensor, result] = canonicalizeSourceOperation(sourceOp); @@ -2076,9 +2118,9 @@ void CanonicalizerVectorOperation::generateEmptyTensorAndWrite( groupOpResults[sourceOpGid].insert(result); } -void CanonicalizerVectorOperation::analysisEmptyGroupAndMaxSteps() { - auto &groupOpResults = commonUsedData.getGroupOpResults(); - auto &opGroups = commonUsedData.getFusionStrategy().getOpGroups(); +void VectorOperationAnalysizer::analysisEmptyGroupAndMaxSteps() { + auto &groupOpResults = getGroupOpResults(); + auto &opGroups = getFusionStrategy().getOpGroups(); // If the group operations do not have result need to be returned, these are // useless code. 
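For the per-group step count, the initialization to the maximum uint32_t value suggests the group takes the smallest lane count among its operations; that reading is an assumption, but it matches getDataTypeMAXSIMDLength computing register bits divided by element bits. A small standalone sketch under that assumption (the avx512 width and the f32/f16 mix are example values):

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const uint32_t registerBits = 512;                // assumes the avx512 path
  std::vector<uint32_t> elementBits = {32, 16, 32}; // f32, f16, f32 ops
  uint32_t steps = UINT32_MAX;                      // same init as the pass
  for (uint32_t bits : elementBits)
    steps = std::min(steps, registerBits / bits);
  // generateValidSteps additionally clamps the step to the innermost
  // dimension size (rounded to a power of two) and to at most 16 lanes.
  std::printf("group step: %u lanes\n", steps);     // 16 for this mix
  return 0;
}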
@@ -2088,7 +2130,7 @@ void CanonicalizerVectorOperation::analysisEmptyGroupAndMaxSteps() { } uint32_t steps = std::numeric_limits::max(); - auto &grpSteps = commonUsedData.getFusionStrategy().getGroupMaxSteps(); + auto &grpSteps = getFusionStrategy().getGroupMaxSteps(); while (idx >= grpSteps.size()) { grpSteps.emplace_back(steps); } @@ -2112,15 +2154,13 @@ void CanonicalizerVectorOperation::analysisEmptyGroupAndMaxSteps() { // analysis operation result of current group whether needed by other // operation which out of current group -void CanonicalizerVectorOperation::analysisGroupOperationResults() { +void VectorOperationAnalysizer::analysisGroupOperationResults() { llvm::DenseMap> srcOpCanoniclizedMap; llvm::DenseSet movedOperationSet; - auto &commonUsedData = getCommonUsedData(); - auto &opGroups = commonUsedData.getFusionStrategy().getOpGroups(); - auto &opGroupIndexMap = - commonUsedData.getFusionStrategy().getOpGroupIndexMap(); - auto &groupOpIterArgs = commonUsedData.getGroupOpIterArgs(); - auto &groupOpResults = commonUsedData.getGroupOpResults(); + auto &opGroups = getFusionStrategy().getOpGroups(); + auto &opGroupIndexMap = getFusionStrategy().getOpGroupIndexMap(); + auto &groupOpIterArgs = getGroupOpIterArgs(); + auto &groupOpResults = getGroupOpResults(); func.walk([&](Operation *op) { for (auto [idx, opd] : llvm::enumerate(op->getOperands())) { auto sourceOp = opd.getDefiningOp(); @@ -2177,16 +2217,15 @@ void CanonicalizerVectorOperation::analysisGroupOperationResults() { LDBG("Complete analysis group operation results\n"); } -void CanonicalizerVectorOperation::analysisGroupOperaionOperandsResults() { +void VectorOperationAnalysizer::analysisGroupOperaion() { // Results analysisGroupOperationResults(); } -mlir::FailureOr generateVectorizedForLoop( - IRRewriter &rewriter, const llvm::SetVector &resultSet, - const llvm::SetVector &dstOperandSet, const VectorType &vectorType, - std::queue &queue, - const llvm::DenseMap &opPermuationMap) { +mlir::FailureOr ForLoopGenerator::generateVectorizedForLoop( + const size_t groupId, IRRewriter &rewriter, const VectorType &vectorType) { + auto &resultSet = getGroupOpResults(); + auto &dstOperandSet = getGroupOpIterArgs()[groupId]; assert(!resultSet.empty() && "Expected non-empty value"); // prepare for loop iterargs llvm::SmallVector operands; @@ -2200,8 +2239,8 @@ mlir::FailureOr generateVectorizedForLoop( llvm::SmallVector inductionVars; // generate for loop auto forOp = constructNestedForOp( - rewriter, rewriter.getUnknownLoc(), iterArgs, vectorType, shapes, 0, - queue, resultSet, inductionVars, operandIdxMap, opPermuationMap); + 0, groupId, rewriter, rewriter.getUnknownLoc(), iterArgs, vectorType, + shapes, inductionVars, operandIdxMap); return forOp; } @@ -2225,21 +2264,20 @@ void updateLoopResultUses(llvm::SetVector &opResults, } } -bool CanonicalizerVectorOperation::isGroupHasSpecialOperation( +bool CanonicalizerCommonUsedData::isGroupHasSpecialOperation( const size_t grpIdx) { - auto &rdCanonicalizer = commonUsedData.getMultiRdCanonicalizer()[grpIdx]; - auto &bcCanonicalizer = commonUsedData.getBroadcastCanonicalizer()[grpIdx]; - auto &tpCanonicalizer = commonUsedData.getTransposeCanonicalizer()[grpIdx]; - auto &shapeCastCanonicalizer = - commonUsedData.getShapeCastCanonicalizer()[grpIdx]; + auto &rdCanonicalizer = getMultiRdCanonicalizers()[grpIdx]; + auto &bcCanonicalizer = getBroadcastCanonicalizers()[grpIdx]; + auto &tpCanonicalizer = getTransposeCanonicalizers()[grpIdx]; + auto &shapeCastCanonicalizer = 
getShapeCastCanonicalizers()[grpIdx]; return !rdCanonicalizer.getCandidateOps().empty() or !bcCanonicalizer.getCandidateOps().empty() or !tpCanonicalizer.getCandidateOps().empty() or !shapeCastCanonicalizer.getCandidateOps().empty(); } -void CanonicalizerVectorOperation::generateGroupOpVectorizedIR(const int idx) { - auto &grp = commonUsedData.getFusionStrategy().getOpGroups()[idx]; +void ForLoopGenerator::generateGroupOpVectorizedIR(const int idx) { + auto &grp = getFusionStrategy().getOpGroups()[idx]; if (grp.empty()) { LDBG("Current operation Group is empty."); return; @@ -2248,9 +2286,7 @@ void CanonicalizerVectorOperation::generateGroupOpVectorizedIR(const int idx) { if (isGroupHasSpecialOperation(idx)) { return; } - auto &groupOpResults = commonUsedData.getGroupOpResults(); - auto &opPermuationMap = commonUsedData.getOpPermuationMap(); - auto &groupOpIterArgs = commonUsedData.getGroupOpIterArgs(); + auto &groupOpResults = getGroupOpResults(); auto getType = getOperationVectorType(grp.front()); if (failed(getType)) { LDBG("Failed to get vector type for operation: " << *grp.front() << "\n"); @@ -2262,9 +2298,7 @@ void CanonicalizerVectorOperation::generateGroupOpVectorizedIR(const int idx) { // 1. Rewrite operation as vectorized form // 2. Generate loop rewriteOperationAsVectorize(rewriter, idx); - auto forOp = generateVectorizedForLoop(rewriter, groupOpResults[idx], - groupOpIterArgs[idx], opShapes, grp, - opPermuationMap); + auto forOp = generateVectorizedForLoop(idx, rewriter, opShapes); // special operation do not need to change anything if (failed(forOp)) { return; @@ -2283,10 +2317,16 @@ struct CPUPhysicalRegisterPass auto *ctx = &getContext(); RewritePatternSet patterns(ctx); auto func = getOperation(); + if (hasNotSupportOperation(&func)) { + LDBG("Not support operation appears in current function."); + return; + } // canonicalize vector operation, default use vector-based fusion strategy. 
CanonicalizerVectorOperation canonicalizer( func, CanonicalizerKind::OperationsGroup); canonicalizer.run(); + // patterns.add(ctx); + (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); } }; } // namespace diff --git a/test/gc/transforms/cpu-vetor-distribution.mlir b/test/gc/transforms/cpu-vetor-distribution.mlir index 64534ebba..04b904f7c 100644 --- a/test/gc/transforms/cpu-vetor-distribution.mlir +++ b/test/gc/transforms/cpu-vetor-distribution.mlir @@ -1,16 +1,46 @@ -// RUN: gc-opt --split-input-file --lower-to-tile-vector --CPU-physical-register-pass --mlir-print-ir-after-all -- %s +// RUN: gc-opt %s --split-input-file --lower-to-tile-vector --CPU-physical-register-pass | FileCheck %s -// CHECK-LABEL: func @add_tensor +// CHECK-LABEL: func @add_tensor_test0 func.func @add_tensor_test0(%arg0: tensor<11008x4096xf32>, %arg1: tensor<11008x4096xf32>) -> tensor<11008x4096xf32> { + // CHECK: %[[C4096:.*]] = arith.constant 4096 : index + // CHECK: %[[C16:.*]] = arith.constant 16 : index + // CHECK: %[[C11008:.*]] = arith.constant 11008 : index + // CHECK: %[[C1:.*]] = arith.constant 1 : index + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 + // CHECK: %[[TENSOR0:.*]] = tensor.empty() : tensor<11008x4096xf32> + // CHECK: scf.for + // CHECK: scf.for + // CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}}: tensor<11008x4096xf32>, vector<16xf32> + // CHECK: %[[READ1:.*]] = vector.transfer_read {{.*}}, {{.*}}: tensor<11008x4096xf32>, vector<16xf32> + // CHECK: %[[ADD0:.*]] = arith.addf %[[READ0]], %[[READ1]] : vector<16xf32> + // CHECK: %[[ADD1:.*]] = arith.addf %[[ADD0]], %[[READ1]] : vector<16xf32> + // CHECK: %[[WRITE:.*]] = vector.transfer_write {{.*}}, {{.*}} : vector<16xf32>, tensor<11008x4096xf32> %0 = tensor.empty() : tensor<11008x4096xf32> %1 = linalg.add ins(%arg0, %arg1 : tensor<11008x4096xf32>, tensor<11008x4096xf32>) outs(%0: tensor<11008x4096xf32>) -> tensor<11008x4096xf32> %2 = linalg.add ins(%1, %arg1 : tensor<11008x4096xf32>, tensor<11008x4096xf32>) outs(%0: tensor<11008x4096xf32>) -> tensor<11008x4096xf32> return %2 : tensor<11008x4096xf32> } +// CHECK-LABEL: func @fc_relu func.func @fc_relu(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>, %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>) -> tensor<512x512xf32> { + // CHECK: scf.for + // CHECK: scf.for + // CHECK: scf.for + // CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}}: tensor<512x512x512xf32>, vector<16xf32> + // CHECK: %[[READ1:.*]] = vector.transfer_read {{.*}}, {{.*}}: tensor<512x512x512xf32>, vector<16xf32> + // CHECK: %[[MULF0:.*]] = arith.mulf %[[READ0]], %[[READ1]] : vector<16xf32> + // CHECK: %[[WRITE:.*]] = vector.transfer_write {{.*}}, {{.*}} : vector<16xf32>, tensor<512x512x512xf32> + // CHECK-DAG: vector.multi_reduction + // CHECK: scf.for + // CHECK: scf.for + // CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}}: tensor<512x512xf32>, vector<16xf32> + // CHECK: %[[READ1:.*]] = vector.transfer_read {{.*}}, {{.*}}: tensor<512x512xf32>, vector<16xf32> + // CHECK: %[[ADD0:.*]] = arith.addf %[[READ1]], %[[READ0]] : vector<16xf32> + // CHECK: %[[ADD1:.*]] = arith.maximumf %[[ADD0]], {{.*}} : vector<16xf32> + // CHECK: %[[WRITE:.*]] = vector.transfer_write {{.*}}, {{.*}} : vector<16xf32>, tensor<512x512xf32> // Matrix-matrix multiplication. 
%matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>) outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32> @@ -29,96 +59,28 @@ func.func @fc_relu(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>, func.return %relued : tensor<512x512xf32> } -func.func @reduce_keepdim0(%arg0: tensor<16x32x64xf32>) -> tensor<16x1x64xf32> { - %0 = tensor.empty() : tensor<16x64xf32> - %reduce = linalg.reduce - ins(%arg0:tensor<16x32x64xf32>) - outs(%0:tensor<16x64xf32>) - dimensions = [1] - (%in: f32, %out: f32) { - %1 = arith.addf %out, %in: f32 - linalg.yield %1: f32 - } - %2 = tensor.expand_shape %reduce [[0],[1, 2]] output_shape [16, 1, 64] : tensor<16x64xf32> into tensor<16x1x64xf32> - return %2 : tensor<16x1x64xf32> -} - -#map = affine_map<(d0) -> (d0 * 64)> -#map1 = affine_map<(d0) -> (d0 * 128)> -#map2 = affine_map<(d0) -> (d0 floordiv 16)> -#map3 = affine_map<(d0) -> (d0 floordiv 32)> -#map4 = affine_map<(d0, d1, d2) -> (d0 + d1 + d2 * 128)> -#map5 = affine_map<(d0, d1, d2) -> (d0 + d1 + d2 * 64)> - func.func @mlp(%arg0: tensor<128x512xbf16>, %arg1: tensor<32x8x16x32xbf16>, %arg2: tensor<256xbf16>) -> tensor<128x256xbf16> { - %c32 = arith.constant 32 : index - %c512 = arith.constant 512 : index - %c128 = arith.constant 128 : index - %c64 = arith.constant 64 : index - %c0 = arith.constant 0 : index - %cst = arith.constant 0.000000e+00 : bf16 - %0 = tensor.empty() : tensor<128x256xbf16> - %1 = tensor.empty() : tensor<512x256xbf16> - %2:3 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %0, %arg6 = %0, %arg7 = %0) -> (tensor<128x256xbf16>, tensor<128x256xbf16>, tensor<128x256xbf16>) { - %3 = affine.apply #map(%arg3) - %4 = affine.apply #map1(%arg4) - %extracted_slice = tensor.extract_slice %arg0[%3, 0] [64, 512] [1, 1] : tensor<128x512xbf16> to tensor<64x512xbf16> - %extracted_slice_0 = tensor.extract_slice %arg5[%3, %4] [64, 128] [1, 1] : tensor<128x256xbf16> to tensor<64x128xbf16> - %5:3 = scf.for %arg8 = %c0 to %c64 step %c64 iter_args(%arg9 = %extracted_slice_0, %arg10 = %extracted_slice_0, %arg11 = %extracted_slice_0) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { - %6:3 = scf.for %arg12 = %c0 to %c128 step %c128 iter_args(%arg13 = %arg9, %arg14 = %arg10, %arg15 = %arg11) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { - %7:3 = scf.for %arg16 = %c0 to %c512 step %c512 iter_args(%arg17 = %arg13, %arg18 = %arg14, %arg19 = %arg15) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { - %extracted_slice_1 = tensor.extract_slice %extracted_slice[%arg8, %arg16] [64, 512] [1, 1] : tensor<64x512xbf16> to tensor<64x512xbf16> - %extracted_slice_2 = tensor.extract_slice %arg17[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> to tensor<64x128xbf16> - %8:3 = scf.for %arg20 = %c0 to %c64 step %c32 iter_args(%arg21 = %extracted_slice_2, %arg22 = %arg18, %arg23 = %arg19) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { - %9:3 = scf.for %arg24 = %c0 to %c128 step %c32 iter_args(%arg25 = %arg21, %arg26 = %arg22, %arg27 = %arg23) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { - %10:3 = scf.for %arg28 = %c0 to %c512 step %c512 iter_args(%arg29 = %arg25, %arg30 = %arg26, %arg31 = %arg27) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { - %extracted_slice_5 = tensor.extract_slice %extracted_slice_1[%arg20, %arg28] [32, 512] [1, 1] : tensor<64x512xbf16> to tensor<32x512xbf16> - %11 = affine.apply #map2(%arg28) - %12 = affine.apply #map3(%arg24) - 
%extracted_slice_6 = tensor.extract_slice %arg1[%11, %12, 0, 0] [32, 1, 16, 32] [1, 1, 1, 1] : tensor<32x8x16x32xbf16> to tensor<32x1x16x32xbf16> - %extracted_slice_7 = tensor.extract_slice %1[%arg28, %arg24] [512, 32] [1, 1] : tensor<512x256xbf16> to tensor<512x32xbf16> - %unpack = tensor.unpack %extracted_slice_6 inner_dims_pos = [0, 1] inner_tiles = [16, 32] into %extracted_slice_7 : tensor<32x1x16x32xbf16> -> tensor<512x32xbf16> - %extracted_slice_8 = tensor.extract_slice %arg29[%arg20, %arg24] [32, 32] [1, 1] : tensor<64x128xbf16> to tensor<32x32xbf16> - %13 = linalg.fill ins(%cst : bf16) outs(%extracted_slice_8 : tensor<32x32xbf16>) -> tensor<32x32xbf16> - %expanded = tensor.expand_shape %extracted_slice_5 [[0, 1], [2]] output_shape [1, 32, 512] : tensor<32x512xbf16> into tensor<1x32x512xbf16> - %expanded_9 = tensor.expand_shape %unpack [[0, 1], [2]] output_shape [1, 32, 512] : tensor<512x32xbf16> into tensor<1x512x32xbf16> - %14 = linalg.batch_reduce_matmul ins(%expanded, %expanded_9 : tensor<1x32x512xbf16>, tensor<1x512x32xbf16>) outs(%13 : tensor<32x32xbf16>) -> tensor<32x32xbf16> - %15 = affine.apply #map4(%arg12, %arg24, %arg4) - %16 = affine.apply #map5(%arg8, %arg20, %arg3) - %extracted_slice_10 = tensor.extract_slice %arg2[%15] [32] [1] : tensor<256xbf16> to tensor<32xbf16> - %extracted_slice_11 = tensor.extract_slice %0[%16, %15] [32, 32] [1, 1] : tensor<128x256xbf16> to tensor<32x32xbf16> - %broadcasted = linalg.broadcast ins(%extracted_slice_10 : tensor<32xbf16>) outs(%extracted_slice_11 : tensor<32x32xbf16>) dimensions = [0] - %extracted_slice_12 = tensor.extract_slice %arg30[%arg20, %arg24] [32, 32] [1, 1] : tensor<64x128xbf16> to tensor<32x32xbf16> - %17 = linalg.add ins(%14, %broadcasted : tensor<32x32xbf16>, tensor<32x32xbf16>) outs(%extracted_slice_12 : tensor<32x32xbf16>) -> tensor<32x32xbf16> - %inserted_slice_13 = tensor.insert_slice %14 into %arg29[%arg20, %arg24] [32, 32] [1, 1] : tensor<32x32xbf16> into tensor<64x128xbf16> - %extracted_slice_14 = tensor.extract_slice %arg31[%arg20, %arg24] [32, 32] [1, 1] : tensor<64x128xbf16> to tensor<32x32xbf16> - %18 = linalg.exp ins(%17 : tensor<32x32xbf16>) outs(%extracted_slice_14 : tensor<32x32xbf16>) -> tensor<32x32xbf16> - %inserted_slice_15 = tensor.insert_slice %17 into %arg30[%arg20, %arg24] [32, 32] [1, 1] : tensor<32x32xbf16> into tensor<64x128xbf16> - %inserted_slice_16 = tensor.insert_slice %18 into %arg31[%arg20, %arg24] [32, 32] [1, 1] : tensor<32x32xbf16> into tensor<64x128xbf16> - scf.yield %inserted_slice_13, %inserted_slice_15, %inserted_slice_16 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> - } - scf.yield %10#0, %10#1, %10#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> - } - scf.yield %9#0, %9#1, %9#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> - } - %inserted_slice = tensor.insert_slice %8#0 into %arg17[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<64x128xbf16> - %inserted_slice_3 = tensor.insert_slice %8#1 into %arg18[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<64x128xbf16> - %inserted_slice_4 = tensor.insert_slice %8#2 into %arg19[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<64x128xbf16> - scf.yield %inserted_slice, %inserted_slice_3, %inserted_slice_4 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> - } - scf.yield %7#0, %7#1, %7#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> - } - scf.yield %6#0, %6#1, %6#2 : tensor<64x128xbf16>, 
tensor<64x128xbf16>, tensor<64x128xbf16> - } - scf.forall.in_parallel { - tensor.parallel_insert_slice %5#2 into %arg7[%3, %4] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<128x256xbf16> - tensor.parallel_insert_slice %5#1 into %arg6[%3, %4] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<128x256xbf16> - tensor.parallel_insert_slice %5#0 into %arg5[%3, %4] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<128x256xbf16> - } - } - return %2#2 : tensor<128x256xbf16> - } - +// CHECK-LABEL: func @matmul_add func.func @matmul_add(%arg0: tensor<8192x12288xf16>, %arg1: tensor<12288x16384xf16>, %arg2: tensor<8192x16384xf32>, %arg3: tensor<8192x16384xf32>) -> tensor<8192x16384xf32> { + // CHECK: vector.broadcast + // CHECK: vector.transpose + // CHECK: vector.broadcast + // CHECK: vector.transpose + // CHECK: scf.for + // CHECK: scf.for + // CHECK: scf.for + // CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}}: tensor<32x32x12288xf16>, vector<16xf16> + // CHECK: %[[EXTF0:.*]] = arith.extf %[[READ0]] : vector<16xf16> to vector<16xf32> + // CHECK: %[[READ1:.*]] = vector.transfer_read {{.*}}, {{.*}} : tensor<32x32x12288xf16>, vector<16xf16> + // CHECK: %[[EXTF1:.*]] = arith.extf %[[READ1]] : vector<16xf16> to vector<16xf32> + // CHECK: %[[MULF0:.*]] = arith.mulf %[[EXTF0]], %[[EXTF1]] : vector<16xf32> + // CHECK: %[[WRITE0:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<16xf32>, tensor<32x32x12288xf32> + // CHECK-DAG: vector.multi_reduction + // CHECK: scf.for + // CHECK: scf.for + // CHECK: %[[READ2:.*]] = vector.transfer_read {{.*}}, {{.*}}: tensor<8192x16384xf32>, vector<16xf32> + // CHECK: %[[READ3:.*]] = vector.transfer_read {{.*}}, {{.*}}: tensor<32x32xf32>, vector<16xf32> + // CHECK: %[[ADD0:.*]] = arith.addf %[[READ3]], %[[READ2]] : vector<16xf32> + // CHECK: %[[WRITE1:.*]] = vector.transfer_write {{.*}}, {{.*}} : vector<16xf32>, tensor<128x128xf32> %0 = linalg.matmul {__fused_op__ = [0], name = "dot", tile_sizes = [[128, 128], [32, 32]]} ins(%arg0, %arg1 : tensor<8192x12288xf16>, tensor<12288x16384xf16>) outs(%arg2 : tensor<8192x16384xf32>) -> tensor<8192x16384xf32> %1 = tensor.empty() : tensor<8192x16384xf32> %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%0, %arg3 : tensor<8192x16384xf32>, tensor<8192x16384xf32>) outs(%1 : tensor<8192x16384xf32>) attrs = {__root_op__ = 0 : i64} { From 26b2ab8cd2cfaa4c778933d07f100a5f6d413c48 Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Thu, 11 Jul 2024 16:17:04 +0800 Subject: [PATCH 15/66] simplify code --- include/gc/Transforms/Passes.h | 1 - include/gc/Transforms/TilingVector.h | 92 +++-- lib/gc/Transforms/CPUPhysicalResigterPass.cpp | 324 ++++++++++-------- .../gc/transforms/cpu-vetor-distribution.mlir | 74 ++++ 4 files changed, 319 insertions(+), 172 deletions(-) diff --git a/include/gc/Transforms/Passes.h b/include/gc/Transforms/Passes.h index 39ecaa485..ce182eb42 100644 --- a/include/gc/Transforms/Passes.h +++ b/include/gc/Transforms/Passes.h @@ -9,7 +9,6 @@ #ifndef GC_PASSES_H #define GC_PASSES_H -#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" #include "mlir/Pass/Pass.h" namespace mlir { diff --git a/include/gc/Transforms/TilingVector.h b/include/gc/Transforms/TilingVector.h index d8abce0fe..4a85002b4 100644 --- a/include/gc/Transforms/TilingVector.h +++ b/include/gc/Transforms/TilingVector.h @@ -1,5 +1,4 @@ -//===- TilingVector.h - Graph Compiler 
passes -------------------------*- C++ -//-*-===// +//===- TilingVector.h - Tiling large vector to small vector ---*- C++ -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -28,27 +27,20 @@ #include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Visitors.h" #include "mlir/Pass/Pass.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/CSE.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" #include "llvm/ADT/APFloat.h" -#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" #include "llvm/Support/ErrorHandling.h" #include -#include #include -#include -#include #include #include #include -#include #include namespace mlir { namespace gc { @@ -61,27 +53,46 @@ void setOperationCorrectOperand( const llvm::SmallVector &inductionVars, const llvm::DenseMap &opPermuationMap); -// 1. Classify operaions: -// classify the operations into : -// a. reorder, transpose. Reorder(or transpose) dim may bring data -// dependency. -// b. elemenwise. Those operations can be fused into a common for loop. -// c. broadcast. Need to analysis broadcast dim and the data -// dependency. -// d. reduction. Need to analysis broadcast dim and the -// data dependency. -// Same group operations have no data dependencies. They can be fused into a -// common for loop body. - -// Using queue to store the operation order. In order to ensure that -// subsequent moves to the operation will not cause semantic changes. +struct HardWareInfo { + bool favx512f = true; + bool favx2 = true; +}; + +/// VectorType conversion helper class +class TypeHelper { +private: + HardWareInfo HWInfo; + +public: + void setHardWareInfo(HardWareInfo &info) { HWInfo = info; } + int getDataTypeValidSteps(const VectorType &type); + int generateValidSteps(int steps, VectorType type); + int getDataTypeMAXSIMDLength(const VectorType &type); + VectorType getVectorzedType(Operation *op, uint32_t loopStep = 0); +}; + +/// Operation fusion strategy class. +/// 1. Classify operaions: +/// classify the operations into : +/// a. reorder, transpose. Reorder(or transpose) dim may bring data +/// dependency. +/// b. elemenwise. Those operations can be fused into a common for loop. +/// c. broadcast. Need to analysis broadcast dim and the data +/// dependency. +/// d. reduction. Need to analysis broadcast dim and the +/// data dependency. +/// Same group operations have no data dependencies. They can be fused into a +/// common for loop body. + +/// Using queue to store the operation order. In order to ensure that +/// subsequent moves to the operation will not cause semantic changes. 
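As the comment block explains, each group is kept as a std::queue so that operations are replayed into the generated loop body in their original program order, keeping every definition ahead of its uses. A tiny standalone illustration (toy strings only, not MLIR operations):

#include <cstdio>
#include <queue>
#include <string>

int main() {
  std::queue<std::string> group;
  group.push("%0 = vector.transfer_read ...");
  group.push("%1 = arith.addf %0, %0 ...");
  group.push("vector.transfer_write %1 ...");
  // moveOperationsToCurrentForBody drains the queue front-to-back, so the
  // producer of %0 lands in the loop body before the op that uses it, and
  // the write that consumes %1 comes last.
  while (!group.empty()) {
    std::printf("move into loop body: %s\n", group.front().c_str());
    group.pop();
  }
  return 0;
}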
class VectorFusionStrategy { private: llvm::SmallVector, 8> opGroups; llvm::SmallVector groupMaxSteps; - // query current operation in which group, return group index + /// query current operation in which group, return group index llvm::DenseMap opGroupIndexMap; - // can fused into prev operation which axis position + /// can fused into prev operation which axis position llvm::DenseMap opAnchorPos; func::FuncOp func; @@ -102,7 +113,14 @@ class VectorFusionStrategy { void classifyOperations(); - // run the vector fusion strategy + /// Check whether the operation can fuse with previous operation + bool isNeedNewGroup(Operation *op); + + /// Add Operation \p op into current last group or a new Group + /// \p op must has valid value, can't be nullptr + void addOperationToGroup(Operation *op); + + /// run the vector-based fusion strategy void run(); }; @@ -116,10 +134,12 @@ template class SpecialOperationCanonicalizer { SpecialOperationCanonicalizer() = default; SpecialOperationCanonicalizer(const llvm::SmallVector &candidateRdOps) : candidateRdOps(candidateRdOps) {} + virtual ~SpecialOperationCanonicalizer() {} llvm::SmallVector &getCandidateOps(); virtual void prepareSpecialOperationInfo() = 0; }; +enum class MultiReduceOpAxisKind { Reduction, Parallel }; class MultiReductionCanonicalizer : public SpecialOperationCanonicalizer { private: @@ -140,6 +160,7 @@ class MultiReductionCanonicalizer isStandaloneOp = candidateRdOps.size() == 1; prepareSpecialOperationInfo(); }; + virtual ~MultiReductionCanonicalizer(){}; int64_t getTypeRank(); void getReductionAxisAndParallelAxis(); bool hasLastDimReduction(); @@ -157,7 +178,7 @@ class MultiReductionCanonicalizer VectorType &getAccType() { return accType; }; llvm::SmallDenseMap &getResultIdxMap() { return resultIdxMap; } void setResultIdxMap(const llvm::SmallDenseMap &map) { - resultIdxMap = std::move(map); + resultIdxMap = map; } void prepareSpecialOperationInfo() override; }; @@ -169,6 +190,7 @@ class BroadcastCanonicalizer BroadcastCanonicalizer( const llvm::SmallVector &candidateBcOps) : SpecialOperationCanonicalizer(candidateBcOps){}; + virtual ~BroadcastCanonicalizer() {} void prepareSpecialOperationInfo() override {} }; @@ -179,6 +201,7 @@ class TransposeCanonicalizer TransposeCanonicalizer( const llvm::SmallVector &candidateTpOps) : SpecialOperationCanonicalizer(candidateTpOps){}; + virtual ~TransposeCanonicalizer() {} void prepareSpecialOperationInfo() override {} }; @@ -189,10 +212,11 @@ class ShapeCastCanonicalizer ShapeCastCanonicalizer( const llvm::SmallVector &candidateScOps) : SpecialOperationCanonicalizer(candidateScOps){}; + virtual ~ShapeCastCanonicalizer() {} void prepareSpecialOperationInfo() override {} }; -class CanonicalizerCommonUsedData { +class CanonicalizerCommonUsedData : virtual public TypeHelper { private: VectorFusionStrategy fusionStrategy; // analysis the operation's operands and results @@ -308,7 +332,7 @@ class ForLoopGenerator : virtual public CanonicalizerCommonUsedData { const size_t groupIdx, OpBuilder &b, const llvm::SmallVector &inductionVars, const llvm::DenseMap &operandIdxMap, - const ValueRange &loopState, const std::queue &queue = {}); + const ValueRange &loopState, std::queue &queue); // multireduction forloop methods scf::ForOp generateMultiReductionForLoop(const size_t grpIdx); @@ -322,6 +346,12 @@ class ForLoopGenerator : virtual public CanonicalizerCommonUsedData { reductionAxisGenerateForLoop(OpBuilder &opBuilder, const int groupIdx, const size_t reductionIdx, ValueRange &initArgs, llvm::SmallVector 
&inductionVars); + + vector::TransferReadOp cloneReductionTransferRead( + Value &source, OpBuilder &b, IRMapping &readMap, + const llvm::SmallVector ¶llelAxis, + llvm::SmallVector &inductionVars, bool lastDimReduction, + MultiReduceOpAxisKind rdKind = MultiReduceOpAxisKind::Parallel); }; class VectorOperationAnalysizer : virtual public CanonicalizerCommonUsedData { @@ -349,10 +379,12 @@ class CanonicalizerVectorOperation : virtual public VectorOperationAnalysizer, public: CanonicalizerVectorOperation( func::FuncOp func, - CanonicalizerKind kind = CanonicalizerKind::OperationsGroup) + CanonicalizerKind kind = CanonicalizerKind::OperationsGroup, + HardWareInfo hwInfo = {}) : func(func), rewriter(func), kind(kind) { setAnalysisFunc(func); setGeneratorFunc(func); + setHardWareInfo(hwInfo); // vector operation fusion if (kind == CanonicalizerKind::OperationsGroup) { auto fusionStrategy = VectorFusionStrategy(func); diff --git a/lib/gc/Transforms/CPUPhysicalResigterPass.cpp b/lib/gc/Transforms/CPUPhysicalResigterPass.cpp index 2d77c2b63..64044d406 100644 --- a/lib/gc/Transforms/CPUPhysicalResigterPass.cpp +++ b/lib/gc/Transforms/CPUPhysicalResigterPass.cpp @@ -17,29 +17,30 @@ namespace { #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +#define ARITH_CAST_OPERATIONS \ + arith::ExtFOp, arith::ExtSIOp, arith::ExtUIOp, arith::BitcastOp, \ + arith::FPToSIOp, arith::FPToUIOp, arith::SIToFPOp, arith::UIToFPOp + // TODO: remove it in the future -bool disableSpecialOp = true; - -struct HardwareInfo { - bool favx512f = true; - bool favx2 = true; -} HW; - -// void printGroupOps(llvm::SmallVector, 8> &opGroups) { -// for (auto grp : opGroups) { -// if (grp.empty()) { -// continue; -// } -// std::cout << "__________________ group start_____________" << std::endl; -// std::queue tmpQ(grp); -// while (!tmpQ.empty()) { -// auto cur = tmpQ.front(); -// tmpQ.pop(); -// cur->dump(); -// } -// std::cout << "__________________ group end_____________" << std::endl; -// } -// } +bool disableSpecialOp = false; + +void printGroupOps(llvm::SmallVector, 8> &opGroups) { + for (auto grp : opGroups) { + if (grp.empty()) { + continue; + } + std::cout << "__________________ group start_____________" << std::endl; + std::queue tmpQ(grp); + while (!tmpQ.empty()) { + auto cur = tmpQ.front(); + tmpQ.pop(); + cur->dump(); + } + std::cout << "__________________ group end_____________" << std::endl; + std::cout << std::endl; + } +} bool isSpecialOp(Operation *op) { return llvm::isa(op) || @@ -98,7 +99,7 @@ int get_nearest_vector_step(const int step) { return (1 << (nbits - 1)) == step ? step : (1 << nbits); } -int generateValidSteps(int steps, VectorType type) { +int TypeHelper::generateValidSteps(int steps, VectorType type) { return type.getShape().back() >= steps ? (steps > 16 ? 
16 : steps) : get_nearest_vector_step(type.getShape().back()); @@ -110,13 +111,13 @@ bool isLastDim(const AffineExpr &expr, const size_t rank) { mlir::dyn_cast(expr).getPosition() == rank - 1; } -[[nodiscard]] int getDataTypeValidSteps(const VectorType &type) { +[[nodiscard]] int TypeHelper::getDataTypeValidSteps(const VectorType &type) { auto typebits = type.getElementTypeBitWidth(); const int favx512bits = 512; const int favx2bits = 256; - if (HW.favx512f) { + if (HWInfo.favx512f) { return generateValidSteps(favx512bits / typebits, type); - } else if (HW.favx2) { + } else if (HWInfo.favx2) { return generateValidSteps(favx2bits / typebits, type); } else { // invalid @@ -127,15 +128,13 @@ bool isLastDim(const AffineExpr &expr, const size_t rank) { } // Get the maximum number of current data types that a register can hold -[[nodiscard]] int getDataTypeMAXSIMDLength(const VectorType &type) { +[[nodiscard]] int TypeHelper::getDataTypeMAXSIMDLength(const VectorType &type) { auto typebits = type.getElementTypeBitWidth(); const int favx512bits = 512; const int favx2bits = 256; - if (HW.favx512f) { - // return generateValidSteps(favx512bits / typebits, type); + if (HWInfo.favx512f) { return favx512bits / typebits; - } else if (HW.favx2) { - // return generateValidSteps(favx2bits / typebits, type); + } else if (HWInfo.favx2) { return favx2bits / typebits; } else { // invalid @@ -192,7 +191,7 @@ mlir::FailureOr getOperationVectorType(Operation *op) { return ret; } -VectorType getVectorzedType(Operation *op, uint32_t loop_step = 0) { +VectorType TypeHelper::getVectorzedType(Operation *op, uint32_t loop_step) { // Check that the operation type can be broken // down into a loop. auto baseType = getOperationVectorType(op); @@ -732,7 +731,6 @@ void classifySourceRelatedOps(std::queue &accRelatedOps, } } -enum class MultiReduceOpAxisKind { Reduction, Parallel }; void updateReduceReadWriteOperationOperand( const llvm::SmallVector &inductionVars, const llvm::SmallVector ¶llelAxis, Operation *op, @@ -747,11 +745,11 @@ void updateReduceReadWriteOperationOperand( } } -vector::TransferReadOp cloneReductionTransferRead( +vector::TransferReadOp ForLoopGenerator::cloneReductionTransferRead( Value &source, OpBuilder &b, IRMapping &readMap, const llvm::SmallVector ¶llelAxis, llvm::SmallVector &inductionVars, bool lastDimReduction, - MultiReduceOpAxisKind rdKind = MultiReduceOpAxisKind::Parallel) { + MultiReduceOpAxisKind rdKind) { IRRewriter rewriter(b); auto readOp = mlir::dyn_cast(source.getDefiningOp()); assert(readOp && " Not transfer_read operation. Current multireduction " @@ -806,9 +804,7 @@ void ForLoopGenerator::moveOperationsToCurrentForBody( const size_t groupIdx, OpBuilder &b, const llvm::SmallVector &inductionVars, const llvm::DenseMap &operandIdxMap, - const ValueRange &loopState, const std::queue &queue) { - std::queue opQueue = - queue.empty() ? 
getFusionStrategy().getOpGroups()[groupIdx] : queue; + const ValueRange &loopState, std::queue &opQueue) { auto &opPermuationMap = getOpPermuationMap(); while (!opQueue.empty()) { auto x = opQueue.front(); @@ -820,6 +816,23 @@ void ForLoopGenerator::moveOperationsToCurrentForBody( } } +bool hasOtherOperations(const std::queue &opQ, + const Operation *multiReductionOp) { + bool res = false; + if (!opQ.empty()) { + std::queue tempQ(opQ); + while (!tempQ.empty()) { + auto cur = tempQ.front(); + tempQ.pop(); + if (!isReadOrWriteOperation(cur) and cur != multiReductionOp) { + res = true; + break; + } + } + } + return res; +}; + scf::ForOp ForLoopGenerator::reductionAxisGenerateForLoop( OpBuilder &opBuilder, const int groupIdx, const size_t reductionIdx, ValueRange &initArgs, llvm::SmallVector &inductionVars) { @@ -893,8 +906,8 @@ scf::ForOp ForLoopGenerator::reductionAxisGenerateForLoop( {cur->getResults()[0], resultArray.size() - 1}); } } - moveOperationsToCurrentForBody(groupIdx, b, loopState, - operandIdxMap, inductionVars); + moveOperationsToCurrentForBody(groupIdx, b, inductionVars, + operandIdxMap, loopState, sourceOps); auto reductionResult = makeArithReduction( b, loc, multireductionOp.getKind(), @@ -914,6 +927,8 @@ scf::ForOp ForLoopGenerator::reductionAxisGenerateForLoop( return forOp; } +// Generate for loop for parallel axis of `vector.multi_reduction`. +// This function also call reduction axis for loop scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( OpBuilder &opBuilder, const int groupIdx, const size_t parallelIdx, ValueRange &initArgs, llvm::SmallVector &inductionVars, @@ -934,6 +949,7 @@ scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( // last dim reduction need to a generate dim=16 loop int dimSize = 0; if (parallelIdx == parallelAxis.size()) { + // TODO: need to consider data type lanes dimSize = 16; } else { dimSize = vectorType.getShape()[parallelAxis[parallelIdx]]; @@ -1035,6 +1051,7 @@ scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( loc, DenseElementsAttr::get(accType, {initValueAttr})); ValueRange newIterArgs(accVal); + auto nxtFor = reductionAxisGenerateForLoop( b, groupIdx, 0, newIterArgs, inductionVars); @@ -1069,66 +1086,69 @@ scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( maybeYieldValue(b, loc, newAccWriteOp->getResults()); } else { - auto prevOp = opQueue.front(); - auto postOp = opQueue.back(); + auto &prevOps = getMultiRdCanonicalizers()[groupIdx].getPrevOps(); auto &postOps = getMultiRdCanonicalizers()[groupIdx].getPostOps(); auto &accRelatedOps = getMultiRdCanonicalizers()[groupIdx].getAccRelatedOps(); auto &sourceRelatedOps = getMultiRdCanonicalizers()[groupIdx].getSourceRelatedOps(); + // prevOp + reduction op + postOp + // reduction op + postOp + getPrevOps(prevOps, opQueue, multiReductionOp); + getPostOps(postOps, opQueue, multiReductionOp); + bool hasPrevOps = hasOtherOperations(prevOps, multiReductionOp); + bool hasPostOps = hasOtherOperations(postOps, multiReductionOp); - if (mlir::isa(prevOp)) { + if (hasPostOps and !hasPrevOps) { + // multi_reduction + postOp } else { - if (mlir::isa(postOp)) { - // prevOp + reduction op - } else { - // prevOp + reduction op + postOp - // reduction op + postOp - getPrevOps(prevOps, opQueue, multiReductionOp); - getPostOps(postOps, opQueue, multiReductionOp); - // analysis acc related operation - classifySourceRelatedOps( - accRelatedOps, sourceRelatedOps, - multiReductionOp.getSource().getDefiningOp(), prevOps); - - rewriteOperationAsVectorize(b, groupIdx, accRelatedOps); - auto 
&grpArgs = getGroupOpIterArgs()[groupIdx]; - llvm::DenseMap operandIdxMap; - for (auto [idx, x] : llvm::enumerate(grpArgs)) { - operandIdxMap[x] = idx; - } - moveOperationsToCurrentForBody(groupIdx, b, inductionVars, - operandIdxMap, loopState, - accRelatedOps); - auto &grpResults = getGroupOpResults()[groupIdx]; - // next for loop - llvm::SmallVector iterArgsArray; - iterArgsArray.emplace_back(multiReductionAcc); - std::queue tmpSourceOps(sourceRelatedOps); - while (!tmpSourceOps.empty()) { - auto cur = tmpSourceOps.front(); - tmpSourceOps.pop(); - auto curResults = cur->getResults(); - for (auto x : curResults) { - if (grpResults.contains(x)) { - for (auto y : cur->getOperands()) { - if (grpArgs.contains(y)) { - iterArgsArray.emplace_back(y); - } + // analysis acc related operation + classifySourceRelatedOps( + accRelatedOps, sourceRelatedOps, + multiReductionOp.getSource().getDefiningOp(), prevOps); + + rewriteOperationAsVectorize(b, groupIdx, accRelatedOps); + auto &grpArgs = getGroupOpIterArgs()[groupIdx]; + llvm::DenseMap operandIdxMap; + for (auto [idx, x] : llvm::enumerate(grpArgs)) { + operandIdxMap[x] = idx; + } + moveOperationsToCurrentForBody(groupIdx, b, inductionVars, + operandIdxMap, loopState, + accRelatedOps); + auto &grpResults = getGroupOpResults()[groupIdx]; + // next for loop + llvm::SmallVector iterArgsArray; + iterArgsArray.emplace_back(multiReductionAcc); + std::queue tmpSourceOps(sourceRelatedOps); + while (!tmpSourceOps.empty()) { + auto cur = tmpSourceOps.front(); + tmpSourceOps.pop(); + auto curResults = cur->getResults(); + for (auto x : curResults) { + if (grpResults.contains(x)) { + for (auto y : cur->getOperands()) { + if (grpArgs.contains(y)) { + iterArgsArray.emplace_back(y); } } } } - ValueRange reductionAxisArgs(iterArgsArray); - auto nxtFor = parallelAxisGenerateForLoop( - b, groupIdx, parallelIdx + 1, reductionAxisArgs, - inductionVars, originalWriteResult); + } + ValueRange reductionAxisArgs(iterArgsArray); + auto nxtFor = parallelAxisGenerateForLoop( + b, groupIdx, parallelIdx + 1, reductionAxisArgs, + inductionVars, originalWriteResult); + if (hasPrevOps and !hasPostOps) { + // prevOp + reduction op + + } else { rewriteOperationAsVectorize(b, groupIdx, postOps); - moveOperationsToCurrentForBody(groupIdx, b, loopState, - operandIdxMap, inductionVars, + moveOperationsToCurrentForBody(groupIdx, b, inductionVars, + operandIdxMap, loopState, postOps); auto replaceIfFn = [&](OpOperand &use) { @@ -1362,7 +1382,7 @@ LogicalResult CanonicalizerVectorOperation::canonicalizeReductionOperation() { continue; } // generate MultiReduction for loops - // (void)generateMultiReductionForLoop(groupId); + (void)generateMultiReductionForLoop(groupId); } return success(); } @@ -1417,6 +1437,7 @@ void CanonicalizerVectorOperation::run() { // Query groupResultYeildSet to map operaion result value to scf.yield // result value. 
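+  // analysisGroupOperaion() runs before special-operation canonicalization
+  // and per-group loop generation; the printGroupOps call added below is
+  // purely a debug aid that dumps every operation queue built by the fusion
+  // strategy.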
analysisGroupOperaion(); + printGroupOps(getFusionStrategy().getOpGroups()); // Speical Operation Canonicalization canonicalizeSpecialOperation(); @@ -1507,8 +1528,9 @@ scf::ForOp ForLoopGenerator::constructNestedForOp( // inner most body of the loop if (forDimIdx == dims.size() - 1) { - moveOperationsToCurrentForBody(groupIdx, b, inductionVars, - operandIdxMap, loopState); + moveOperationsToCurrentForBody( + groupIdx, b, inductionVars, operandIdxMap, loopState, + getFusionStrategy().getOpGroups()[groupIdx]); auto &resultSet = getGroupOpResults()[groupIdx]; maybeYieldValue(b, loc, resultSet.getArrayRef()); } else { @@ -1522,7 +1544,7 @@ scf::ForOp ForLoopGenerator::constructNestedForOp( return forOp; } -bool isCompatibleVectorType(Operation *op1, Operation *op2) { +bool isSameVectorType(Operation *op1, Operation *op2) { auto type1 = getOperationVectorType(op1); auto type2 = getOperationVectorType(op2); if (failed(type1) || failed(type2)) { @@ -1530,22 +1552,42 @@ bool isCompatibleVectorType(Operation *op1, Operation *op2) { } auto sp1 = type1.value(); auto sp2 = type2.value(); - if (isReadOrWriteOperation(op1) or isReadOrWriteOperation(op2)) { - if (sp1.getRank() != sp2.getRank()) { - return false; - } - for (long i = 0; i < sp1.getRank(); i++) { - if (sp1.getDimSize(i) != sp2.getDimSize(i)) { - return false; - } + if (sp1.getRank() != sp2.getRank()) { + return false; + } + bool isSame = true; + // from front to back + for (long i = 0; i < sp1.getRank(); i++) { + if (sp1.getDimSize(i) != sp2.getDimSize(i)) { + isSame = false; + break; } } - if (sp1.getRank() != sp2.getRank()) { + return isSame; +} + +bool isCompatibleVectorType(Operation *op1, Operation *op2) { + auto type1 = getOperationVectorType(op1); + auto type2 = getOperationVectorType(op2); + if (failed(type1) || failed(type2)) { return false; } + auto sp1 = type1.value(); + auto sp2 = type2.value(); + // if (isReadOrWriteOperation(op1) or isReadOrWriteOperation(op2)) { + // if (sp1.getRank() != sp2.getRank()) { + // return false; + // } + // for (long i = 0; i < sp1.getRank(); i++) { + // if (sp1.getDimSize(i) != sp2.getDimSize(i)) { + // return false; + // } + // } + // } bool isCompatible = true; + auto min_rank = std::min(sp1.getRank(), sp2.getRank()); // from front to back - for (long i = 0; i < sp1.getRank(); i++) { + for (long i = 0; i < min_rank; i++) { if (sp1.getDimSize(i) != sp2.getDimSize(i)) { isCompatible = false; break; @@ -1740,14 +1782,17 @@ bool hasDataDependency(Operation *op1, Operation *op2) { }) .Case( [&](vector::MultiDimReductionOp multiReductionOp) { - // op1 is special operation, op2 is normal operation - // op1 and op2 is both speicial operation + // has two cases: op1 is special operation, op2 is normal + // operation op1 and op2 is both speicial operation llvm::SmallVector dims2, reductionDims, parallelDims; getOperationDataAxis(op1, reductionDims); getOperationDataAxis(op2, dims2); llvm::DenseSet checkSet(dims2.begin(), dims2.end()); - + auto op2VectorType = getOperationVectorType(op2); if (!isSpecialOp(op2)) { + if (isSameVectorType(op1, op2)) { + return false; + } // all reduction axis should be op2's data axis bool reduceDependent = false; for (auto x : reductionDims) { @@ -1771,9 +1816,7 @@ bool hasDataDependency(Operation *op1, Operation *op2) { } checkSet.clear(); checkSet.insert(parallelDims.begin(), parallelDims.end()); - auto rank = - mlir::dyn_cast(op2->getResultTypes()[0]) - .getRank(); + auto rank = op2VectorType->getRank(); for (auto i = 0; i < rank; i++) { if (!checkSet.contains(i)) { 
return true; @@ -1781,15 +1824,16 @@ bool hasDataDependency(Operation *op1, Operation *op2) { } return false; - } else { - // TODO: reduce operation fused with other special operation - if (mlir::isa(op2)) { - return true; - } else if (mlir::isa(op2)) { - return true; - } - //... } + // else { + // // TODO: reduce operation fused with other special + // operation if (mlir::isa(op2)) { + // return true; + // } else if (mlir::isa(op2)) { + // return true; + // } + // //... + // } return true; }) @@ -1797,6 +1841,7 @@ bool hasDataDependency(Operation *op1, Operation *op2) { llvm::SmallVector dims1, dims2; getOperationDataAxis(op1, dims1); getOperationDataAxis(op2, dims2); + return true; if (!isSpecialOp(op2)) { return hasSameAxis(dims1, dims2); } else { @@ -1807,6 +1852,7 @@ bool hasDataDependency(Operation *op1, Operation *op2) { llvm::SmallVector dims1, dims2; getOperationDataAxis(op1, dims1); getOperationDataAxis(op2, dims2); + return true; if (!isSpecialOp(op2)) { return hasSameAxis(dims1, dims2); } else { @@ -1818,8 +1864,7 @@ bool hasDataDependency(Operation *op1, Operation *op2) { return res; } -bool isNeedNewGroup(llvm::SmallVector, 8> &opGroups, - Operation *op) { +bool VectorFusionStrategy::isNeedNewGroup(Operation *op) { // 1. check previous operation if (!opGroups.back().empty()) { auto prevOp = opGroups.back().back(); @@ -1834,20 +1879,15 @@ bool isNeedNewGroup(llvm::SmallVector, 8> &opGroups, // previous operation vector type is not compatible with current operation if (!isCompatibleVectorType(prevOp, op)) { - // TODO: Support partial compatible operation fusion - if (isPartialCompatible(prevOp, op)) { - } return true; } } return false; } -void addOperationToGroup( - llvm::SmallVector, 8> &opGroups, - llvm::DenseMap &opGroupIndexMap, Operation *op) { - // - if (isNeedNewGroup(opGroups, op)) { +void VectorFusionStrategy::addOperationToGroup(Operation *op) { + assert(op); + if (isNeedNewGroup(op)) { opGroups.emplace_back(std::queue()); } opGroups.back().push(op); @@ -1864,7 +1904,7 @@ void VectorFusionStrategy::classifyOperations() { } func->walk([&](Operation *op) { if (filterOperation(op)) { - addOperationToGroup(opGroups, opGroupIndexMap, op); + addOperationToGroup(op); } }); } @@ -1977,24 +2017,27 @@ void ForLoopGenerator::createNewConstantOp( /// Rewrite the operations in the group to vectorized form. void ForLoopGenerator::rewriteOperationAsVectorize( OpBuilder &rewriter, size_t groupId, const std::queue &queue) { - auto &groupOps = + const std::queue groupOps = queue.empty() ? 
getFusionStrategy().getOpGroups()[groupId] : queue; - auto &opMap = getFusionStrategy().getOpGroupIndexMap(); - auto &opPermuationMap = getOpPermuationMap(); + const llvm::DenseMap &opMap = + getFusionStrategy().getOpGroupIndexMap(); + llvm::DenseMap &opPermuationMap = + getOpPermuationMap(); std::queue transformQueue(groupOps); - auto groupSteps = getFusionStrategy().getGroupMaxSteps()[groupId]; + size_t groupSteps = getFusionStrategy().getGroupMaxSteps()[groupId]; while (!transformQueue.empty()) { - auto op = transformQueue.front(); + Operation *op = transformQueue.front(); transformQueue.pop(); + VectorType newOperandType = getVectorzedType(op, groupSteps); auto lowerResult = TypeSwitch(op) .Case( [&](vector::TransferWriteOp transferWriteOp) { IRRewriter rewriter(transferWriteOp); - auto newOperandType = - getVectorzedType(transferWriteOp, groupSteps); - auto srcOp = transferWriteOp->getOperand(0).getDefiningOp(); + + Operation *srcOp = + transferWriteOp->getOperand(0).getDefiningOp(); if (mlir::isa(srcOp)) { createNewConstantOp(srcOp, &transferWriteOp); } else { @@ -2013,8 +2056,6 @@ void ForLoopGenerator::rewriteOperationAsVectorize( }) .Case( [&](vector::TransferReadOp transferReadOp) { - auto newOperandType = - getVectorzedType(transferReadOp, groupSteps); opPermuationMap.insert( {transferReadOp, transferReadOp.getPermutationMap()}); transferReadOp->getResult(0).setType(newOperandType); @@ -2032,8 +2073,7 @@ void ForLoopGenerator::rewriteOperationAsVectorize( "It should not appear this operation."); return failure(); }) - .Case([&](arith::ExtFOp extfOp) { - auto newOperandType = getVectorzedType(extfOp, groupSteps); + .Case([&](Operation *extfOp) { extfOp->getResult(0).setType(newOperandType); return success(); }) @@ -2043,8 +2083,7 @@ void ForLoopGenerator::rewriteOperationAsVectorize( "It should not appear this operation."); return failure(); } - setOperationOperandResult(op, getVectorzedType(op, groupSteps), - opMap); + setOperationOperandResult(op, newOperandType, opMap); return success(); }); if (failed(lowerResult)) { @@ -2075,7 +2114,7 @@ void updateOpOperandResultInGroups( llvm::SmallVector, 8> &opGroups, llvm::DenseMap &opGroupIndexMap, size_t opGid, Operation *op, Value &init, const Value &result = Value()) { - auto tmpOpQueue(opGroups[opGid]); + std::queue tmpOpQueue(opGroups[opGid]); std::queue newOpQueue; while (!tmpOpQueue.empty()) { auto curOp = tmpOpQueue.front(); @@ -2322,10 +2361,13 @@ struct CPUPhysicalRegisterPass return; } // canonicalize vector operation, default use vector-based fusion strategy. 
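+    // HardWareInfo drives the SIMD width chosen by the type helper:
+    // favx512f (default true) selects 512-bit vectors and favx2 falls back
+    // to 256-bit; getDataTypeValidSteps / getDataTypeMAXSIMDLength consult
+    // these flags when computing the vectorized loop step.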
+ HardWareInfo hwInfo; + // default has avx512f instructions + // hwInfo.favx512f = false; CanonicalizerVectorOperation canonicalizer( - func, CanonicalizerKind::OperationsGroup); + func, CanonicalizerKind::OperationsGroup, hwInfo); canonicalizer.run(); - // patterns.add(ctx); + (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); } }; diff --git a/test/gc/transforms/cpu-vetor-distribution.mlir b/test/gc/transforms/cpu-vetor-distribution.mlir index 04b904f7c..60d572d2d 100644 --- a/test/gc/transforms/cpu-vetor-distribution.mlir +++ b/test/gc/transforms/cpu-vetor-distribution.mlir @@ -141,3 +141,77 @@ func.func @matmul_add(%arg0: tensor<8192x12288xf16>, %arg1: tensor<12288x16384xf return %3 : tensor<8192x16384xf32> } +#map = affine_map<(d0) -> (d0 * 64)> +#map1 = affine_map<(d0) -> (d0 * 128)> +#map2 = affine_map<(d0) -> (d0 floordiv 16)> +#map3 = affine_map<(d0) -> (d0 floordiv 32)> +#map4 = affine_map<(d0, d1, d2) -> (d0 + d1 + d2 * 128)> +#map5 = affine_map<(d0, d1, d2) -> (d0 + d1 + d2 * 64)> + func.func @mlp(%arg0: tensor<128x512xbf16>, %arg1: tensor<32x8x16x32xbf16>, %arg2: tensor<256xbf16>) -> tensor<128x256xbf16> { + %c32 = arith.constant 32 : index + %c512 = arith.constant 512 : index + %c128 = arith.constant 128 : index + %c64 = arith.constant 64 : index + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : bf16 + %0 = tensor.empty() : tensor<128x256xbf16> + %1 = tensor.empty() : tensor<512x256xbf16> + %2:3 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %0, %arg6 = %0, %arg7 = %0) -> (tensor<128x256xbf16>, tensor<128x256xbf16>, tensor<128x256xbf16>) { + %3 = affine.apply #map(%arg3) + %4 = affine.apply #map1(%arg4) + %extracted_slice = tensor.extract_slice %arg0[%3, 0] [64, 512] [1, 1] : tensor<128x512xbf16> to tensor<64x512xbf16> + %extracted_slice_0 = tensor.extract_slice %arg5[%3, %4] [64, 128] [1, 1] : tensor<128x256xbf16> to tensor<64x128xbf16> + %5:3 = scf.for %arg8 = %c0 to %c64 step %c64 iter_args(%arg9 = %extracted_slice_0, %arg10 = %extracted_slice_0, %arg11 = %extracted_slice_0) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { + %6:3 = scf.for %arg12 = %c0 to %c128 step %c128 iter_args(%arg13 = %arg9, %arg14 = %arg10, %arg15 = %arg11) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { + %7:3 = scf.for %arg16 = %c0 to %c512 step %c512 iter_args(%arg17 = %arg13, %arg18 = %arg14, %arg19 = %arg15) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { + %extracted_slice_1 = tensor.extract_slice %extracted_slice[%arg8, %arg16] [64, 512] [1, 1] : tensor<64x512xbf16> to tensor<64x512xbf16> + %extracted_slice_2 = tensor.extract_slice %arg17[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> to tensor<64x128xbf16> + %8:3 = scf.for %arg20 = %c0 to %c64 step %c32 iter_args(%arg21 = %extracted_slice_2, %arg22 = %arg18, %arg23 = %arg19) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { + %9:3 = scf.for %arg24 = %c0 to %c128 step %c32 iter_args(%arg25 = %arg21, %arg26 = %arg22, %arg27 = %arg23) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { + %10:3 = scf.for %arg28 = %c0 to %c512 step %c512 iter_args(%arg29 = %arg25, %arg30 = %arg26, %arg31 = %arg27) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { + %extracted_slice_5 = tensor.extract_slice %extracted_slice_1[%arg20, %arg28] [32, 512] [1, 1] : tensor<64x512xbf16> to tensor<32x512xbf16> + %11 = affine.apply #map2(%arg28) + %12 = affine.apply #map3(%arg24) + %extracted_slice_6 = 
tensor.extract_slice %arg1[%11, %12, 0, 0] [32, 1, 16, 32] [1, 1, 1, 1] : tensor<32x8x16x32xbf16> to tensor<32x1x16x32xbf16> + %extracted_slice_7 = tensor.extract_slice %1[%arg28, %arg24] [512, 32] [1, 1] : tensor<512x256xbf16> to tensor<512x32xbf16> + %unpack = tensor.unpack %extracted_slice_6 inner_dims_pos = [0, 1] inner_tiles = [16, 32] into %extracted_slice_7 : tensor<32x1x16x32xbf16> -> tensor<512x32xbf16> + %extracted_slice_8 = tensor.extract_slice %arg29[%arg20, %arg24] [32, 32] [1, 1] : tensor<64x128xbf16> to tensor<32x32xbf16> + %13 = linalg.fill ins(%cst : bf16) outs(%extracted_slice_8 : tensor<32x32xbf16>) -> tensor<32x32xbf16> + %expanded = tensor.expand_shape %extracted_slice_5 [[0, 1], [2]] output_shape [1, 32, 512] : tensor<32x512xbf16> into tensor<1x32x512xbf16> + %expanded_9 = tensor.expand_shape %unpack [[0, 1], [2]] output_shape [1, 32, 512] : tensor<512x32xbf16> into tensor<1x512x32xbf16> + %14 = linalg.batch_reduce_matmul ins(%expanded, %expanded_9 : tensor<1x32x512xbf16>, tensor<1x512x32xbf16>) outs(%13 : tensor<32x32xbf16>) -> tensor<32x32xbf16> + %15 = affine.apply #map4(%arg12, %arg24, %arg4) + %16 = affine.apply #map5(%arg8, %arg20, %arg3) + %extracted_slice_10 = tensor.extract_slice %arg2[%15] [32] [1] : tensor<256xbf16> to tensor<32xbf16> + %extracted_slice_11 = tensor.extract_slice %0[%16, %15] [32, 32] [1, 1] : tensor<128x256xbf16> to tensor<32x32xbf16> + %broadcasted = linalg.broadcast ins(%extracted_slice_10 : tensor<32xbf16>) outs(%extracted_slice_11 : tensor<32x32xbf16>) dimensions = [0] + %extracted_slice_12 = tensor.extract_slice %arg30[%arg20, %arg24] [32, 32] [1, 1] : tensor<64x128xbf16> to tensor<32x32xbf16> + %17 = linalg.add ins(%14, %broadcasted : tensor<32x32xbf16>, tensor<32x32xbf16>) outs(%extracted_slice_12 : tensor<32x32xbf16>) -> tensor<32x32xbf16> + %inserted_slice_13 = tensor.insert_slice %14 into %arg29[%arg20, %arg24] [32, 32] [1, 1] : tensor<32x32xbf16> into tensor<64x128xbf16> + %extracted_slice_14 = tensor.extract_slice %arg31[%arg20, %arg24] [32, 32] [1, 1] : tensor<64x128xbf16> to tensor<32x32xbf16> + %18 = linalg.exp ins(%17 : tensor<32x32xbf16>) outs(%extracted_slice_14 : tensor<32x32xbf16>) -> tensor<32x32xbf16> + %inserted_slice_15 = tensor.insert_slice %17 into %arg30[%arg20, %arg24] [32, 32] [1, 1] : tensor<32x32xbf16> into tensor<64x128xbf16> + %inserted_slice_16 = tensor.insert_slice %18 into %arg31[%arg20, %arg24] [32, 32] [1, 1] : tensor<32x32xbf16> into tensor<64x128xbf16> + scf.yield %inserted_slice_13, %inserted_slice_15, %inserted_slice_16 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> + } + scf.yield %10#0, %10#1, %10#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> + } + scf.yield %9#0, %9#1, %9#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> + } + %inserted_slice = tensor.insert_slice %8#0 into %arg17[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<64x128xbf16> + %inserted_slice_3 = tensor.insert_slice %8#1 into %arg18[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<64x128xbf16> + %inserted_slice_4 = tensor.insert_slice %8#2 into %arg19[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<64x128xbf16> + scf.yield %inserted_slice, %inserted_slice_3, %inserted_slice_4 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> + } + scf.yield %7#0, %7#1, %7#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> + } + scf.yield %6#0, %6#1, %6#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, 
tensor<64x128xbf16> + } + scf.forall.in_parallel { + tensor.parallel_insert_slice %5#2 into %arg7[%3, %4] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<128x256xbf16> + tensor.parallel_insert_slice %5#1 into %arg6[%3, %4] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<128x256xbf16> + tensor.parallel_insert_slice %5#0 into %arg5[%3, %4] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<128x256xbf16> + } + } + return %2#2 : tensor<128x256xbf16> + } From 380f173ea94b1d2e38bbf257d1e8efa4c5a28484 Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Thu, 18 Jul 2024 17:08:02 +0800 Subject: [PATCH 16/66] refactor partial compitable operation fusion --- include/gc/Transforms/TilingVector.h | 283 +++- lib/gc/Transforms/CPUPhysicalResigterPass.cpp | 1382 ++++++++++------- .../gc/transforms/cpu-vetor-distribution.mlir | 148 +- 3 files changed, 1093 insertions(+), 720 deletions(-) diff --git a/include/gc/Transforms/TilingVector.h b/include/gc/Transforms/TilingVector.h index 4a85002b4..a1e1b84e8 100644 --- a/include/gc/Transforms/TilingVector.h +++ b/include/gc/Transforms/TilingVector.h @@ -32,6 +32,7 @@ #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" #include "llvm/ADT/APFloat.h" #include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/MapVector.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" @@ -50,7 +51,7 @@ Value makeIndexArithConstantOp(OpBuilder &opBuilder, Location &loc, int64_t x); void setOperationCorrectOperand( Operation *op, const ValueRange &iterArgs, const llvm::DenseMap &operandIdxMap, - const llvm::SmallVector &inductionVars, + ArrayRef inductionVars, const llvm::DenseMap &opPermuationMap); struct HardWareInfo { @@ -65,9 +66,9 @@ class TypeHelper { public: void setHardWareInfo(HardWareInfo &info) { HWInfo = info; } - int getDataTypeValidSteps(const VectorType &type); + int getDataTypeValidSteps(VectorType type); int generateValidSteps(int steps, VectorType type); - int getDataTypeMAXSIMDLength(const VectorType &type); + int getDataTypeMAXSIMDLength(VectorType type); VectorType getVectorzedType(Operation *op, uint32_t loopStep = 0); }; @@ -86,18 +87,39 @@ class TypeHelper { /// Using queue to store the operation order. In order to ensure that /// subsequent moves to the operation will not cause semantic changes. 
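+/// A minimal usage sketch (hypothetical driver code; only members declared in
+/// this header are assumed):
+///
+///   VectorFusionStrategy strategy(funcOp);
+///   strategy.run();   // groups the vectorizable operations in funcOp
+///   for (auto &group : strategy.getOpGroups()) {
+///     // each `group` is a queue of operations expected to share one
+///     // vectorized loop nest
+///   }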
-class VectorFusionStrategy { +class VectorFusionStrategy : public TypeHelper { private: + func::FuncOp func; llvm::SmallVector, 8> opGroups; llvm::SmallVector groupMaxSteps; + /// vector type which has bigest rank in current operation group + llvm::SmallDenseMap groupBigestRankVectorType; /// query current operation in which group, return group index llvm::DenseMap opGroupIndexMap; /// can fused into prev operation which axis position - llvm::DenseMap opAnchorPos; - - func::FuncOp func; + llvm::DenseMap opAnchorPos; public: + VectorFusionStrategy() = default; + VectorFusionStrategy(func::FuncOp &func) : func(func) {} + VectorFusionStrategy(func::FuncOp &func, TypeHelper &typeHelper) + : TypeHelper(typeHelper), func(func) {} + VectorFusionStrategy(VectorFusionStrategy &strategy) + : func(strategy.func), opGroups(strategy.opGroups), + groupMaxSteps(strategy.groupMaxSteps), + opGroupIndexMap(strategy.opGroupIndexMap), + opAnchorPos(strategy.opAnchorPos){}; + VectorFusionStrategy(VectorFusionStrategy &&strategy) + : func(std::move(strategy.func)), opGroups(std::move(strategy.opGroups)), + groupMaxSteps(std::move(strategy.groupMaxSteps)), + opGroupIndexMap(std::move(strategy.opGroupIndexMap)), + opAnchorPos(std::move(strategy.opAnchorPos)){}; + + VectorFusionStrategy &operator=(VectorFusionStrategy &&) = default; + + llvm::SmallDenseMap &getGroupBiggestRankVectorType() { + return groupBigestRankVectorType; + }; llvm::SmallVector, 8> &getOpGroups() { return opGroups; } @@ -105,14 +127,17 @@ class VectorFusionStrategy { return opGroupIndexMap; } llvm::SmallVector &getGroupMaxSteps() { return groupMaxSteps; } + llvm::DenseMap &getOpAnchorPos() { return opAnchorPos; } - func::FuncOp getFunc() { return func; } - - VectorFusionStrategy() = default; - VectorFusionStrategy(func::FuncOp func) : func(func) {} + func::FuncOp &getFunc() { return func; } void classifyOperations(); + /// Whether two operations have compatible vector shapes + bool isCompatibleVectorType(Operation *op1, Operation *op2); + + void updateGroupBitgestVectorType(VectorType vectorType); + /// Check whether the operation can fuse with previous operation bool isNeedNewGroup(Operation *op); @@ -130,13 +155,26 @@ template class SpecialOperationCanonicalizer { private: llvm::SmallVector candidateRdOps; +public: + enum class SpecialOperationKind { + OP_MultiDimReduction, + OP_Broadcast, + OP_Transpose, + OP_ShapeCast + }; + +private: + const SpecialOperationKind kind; + public: SpecialOperationCanonicalizer() = default; - SpecialOperationCanonicalizer(const llvm::SmallVector &candidateRdOps) - : candidateRdOps(candidateRdOps) {} - virtual ~SpecialOperationCanonicalizer() {} + SpecialOperationCanonicalizer(const llvm::SmallVector &candidateRdOps, + SpecialOperationKind kind) + : candidateRdOps(candidateRdOps), kind(kind) {} llvm::SmallVector &getCandidateOps(); + virtual ~SpecialOperationCanonicalizer() {} virtual void prepareSpecialOperationInfo() = 0; + SpecialOperationKind getKind() { return kind; } }; enum class MultiReduceOpAxisKind { Reduction, Parallel }; @@ -156,7 +194,7 @@ class MultiReductionCanonicalizer MultiReductionCanonicalizer( const llvm::SmallVector &candidateRdOps) : SpecialOperationCanonicalizer( - candidateRdOps) { + candidateRdOps, SpecialOperationKind::OP_MultiDimReduction) { isStandaloneOp = candidateRdOps.size() == 1; prepareSpecialOperationInfo(); }; @@ -174,13 +212,19 @@ class MultiReductionCanonicalizer std::queue &getAccRelatedOps() { return accRelatedOps; } std::queue &getSourceRelatedOps() { return 
sourceRelatedOps; } llvm::SetVector &getOriginalOpResults() { return originalOpResults; } - VectorType &getSourceType() { return sourceType; }; - VectorType &getAccType() { return accType; }; + VectorType getSourceType() { return sourceType; }; + VectorType getAccType() { return accType; }; llvm::SmallDenseMap &getResultIdxMap() { return resultIdxMap; } void setResultIdxMap(const llvm::SmallDenseMap &map) { resultIdxMap = map; } + void prepareSpecialOperationInfo() override; + + static bool classof(SpecialOperationCanonicalizer *canonicalizer) { + return canonicalizer->getKind() == + SpecialOperationKind::OP_MultiDimReduction; + } }; class BroadcastCanonicalizer @@ -189,9 +233,13 @@ class BroadcastCanonicalizer public: BroadcastCanonicalizer( const llvm::SmallVector &candidateBcOps) - : SpecialOperationCanonicalizer(candidateBcOps){}; + : SpecialOperationCanonicalizer( + candidateBcOps, SpecialOperationKind::OP_Broadcast){}; virtual ~BroadcastCanonicalizer() {} void prepareSpecialOperationInfo() override {} + static bool classof(SpecialOperationCanonicalizer *canonicalizer) { + return canonicalizer->getKind() == SpecialOperationKind::OP_Broadcast; + } }; class TransposeCanonicalizer @@ -200,9 +248,13 @@ class TransposeCanonicalizer public: TransposeCanonicalizer( const llvm::SmallVector &candidateTpOps) - : SpecialOperationCanonicalizer(candidateTpOps){}; + : SpecialOperationCanonicalizer( + candidateTpOps, SpecialOperationKind::OP_Transpose){}; virtual ~TransposeCanonicalizer() {} void prepareSpecialOperationInfo() override {} + static bool classof(SpecialOperationCanonicalizer *canonicalizer) { + return canonicalizer->getKind() == SpecialOperationKind::OP_Transpose; + } }; class ShapeCastCanonicalizer @@ -211,16 +263,30 @@ class ShapeCastCanonicalizer public: ShapeCastCanonicalizer( const llvm::SmallVector &candidateScOps) - : SpecialOperationCanonicalizer(candidateScOps){}; + : SpecialOperationCanonicalizer( + candidateScOps, SpecialOperationKind::OP_ShapeCast){}; virtual ~ShapeCastCanonicalizer() {} void prepareSpecialOperationInfo() override {} + static bool classof(SpecialOperationCanonicalizer *canonicalizer) { + return canonicalizer->getKind() == SpecialOperationKind::OP_ShapeCast; + } +}; + +enum class ReturnTypeKind { + RT_Both, + RT_OutGroup, + RT_InGroup, }; -class CanonicalizerCommonUsedData : virtual public TypeHelper { +class CanonicalizerCommonUsedData : public TypeHelper { private: VectorFusionStrategy fusionStrategy; - // analysis the operation's operands and results - llvm::SmallVector, 8> groupOpResults, groupOpIterArgs; + +private: + /// analysis the operation's operands and results + SmallVector>, 8> + groupOpResults; + llvm::SmallVector, 8> groupOpInitArgs; // store read and write operations permutation maps in order to convenient // to replace loop induction var @@ -237,48 +303,56 @@ class CanonicalizerCommonUsedData : virtual public TypeHelper { CanonicalizerCommonUsedData( VectorFusionStrategy &fusionStrategy, - llvm::SmallVector, 8> &groupOpResults, - llvm::SmallVector, 8> &groupOpIterArgs, + llvm::SmallVector< + llvm::MapVector>, 8> + &groupOpResults, + llvm::SmallVector, 8> &groupOpInitArgs, llvm::DenseMap &opPermuationMap) : fusionStrategy(fusionStrategy), groupOpResults(groupOpResults), - groupOpIterArgs(groupOpIterArgs), opPermuationMap(opPermuationMap) {} + groupOpInitArgs(groupOpInitArgs), opPermuationMap(opPermuationMap) {} virtual ~CanonicalizerCommonUsedData(){}; - // set methods - void setFuseStrategy(VectorFusionStrategy &strategy) { - fusionStrategy 
= strategy; - auto opGroups = fusionStrategy.getOpGroups(); + /// Set fusion strategy + void setFuseStrategy(VectorFusionStrategy &&strategy) { + fusionStrategy = std::move(strategy); + llvm::SmallVector, 8> &opGroups = + fusionStrategy.getOpGroups(); + // init operations results and initialization args if (opGroups.size() != groupOpResults.size() || - opGroups.size() != groupOpIterArgs.size()) { + opGroups.size() != groupOpInitArgs.size()) { groupOpResults.clear(); - groupOpIterArgs.clear(); + groupOpInitArgs.clear(); for (size_t i = 0; i < opGroups.size(); i++) { - groupOpResults.emplace_back(llvm::SetVector()); - groupOpIterArgs.emplace_back(llvm::SetVector()); + groupOpResults.emplace_back( + llvm::MapVector>()); + groupOpInitArgs.emplace_back(llvm::SetVector()); } } } - void - setGroupOpResults(llvm::SmallVector, 8> &results) { - groupOpResults = results; + void setGroupOpResults( + const SmallVector< + llvm::MapVector>, 8> + &results) { + groupOpResults = std::move(results); } - void - setGroupOpIterArgs(llvm::SmallVector, 8> &iterArgs) { - groupOpIterArgs = iterArgs; + void setGroupOpIterArgs( + const llvm::SmallVector, 8> &initArgs) { + groupOpInitArgs = std::move(initArgs); } - void setPermutationMap(llvm::DenseMap &map) { - opPermuationMap = map; + void setPermutationMap(const llvm::DenseMap &map) { + opPermuationMap = std::move(map); } // get methods VectorFusionStrategy &getFusionStrategy() { return fusionStrategy; } - llvm::SmallVector, 8> &getGroupOpResults() { + SmallVector>, 8> & + getGroupOpResults() { return groupOpResults; } - llvm::SmallVector, 8> &getGroupOpIterArgs() { - return groupOpIterArgs; + llvm::SmallVector, 8> &getGroupOpInitArgs() { + return groupOpInitArgs; } llvm::DenseMap &getOpPermuationMap() { @@ -307,40 +381,108 @@ class CanonicalizerCommonUsedData : virtual public TypeHelper { }; class ForLoopGenerator : virtual public CanonicalizerCommonUsedData { +private: func::FuncOp func; public: + ForLoopGenerator() = default; + ForLoopGenerator(func::FuncOp &func) : func(func) {} + virtual ~ForLoopGenerator() {} void setGeneratorFunc(func::FuncOp &func) { this->func = func; } void generateGroupOpVectorizedIR(const int idx); - void rewriteOperationAsVectorize(OpBuilder &rewriter, size_t groupId, - const std::queue &queue = {}); + void + rewriteOperationAsVectorize(OpBuilder &rewriter, size_t groupId, + const std::queue *queue = nullptr); void createNewConstantOp(Operation *srcOp, vector::TransferWriteOp *transferWriteOp); // elementwise for loop mlir::FailureOr generateVectorizedForLoop(const size_t groupId, IRRewriter &rewriter, - const VectorType &vectorType); + const VectorType vectorType); + scf::ForOp constructNestedForOp(const size_t forDimIdx, const size_t groupIdx, OpBuilder &b, const Location &loc, - const ValueRange &iterArgs, const VectorType &type, + const ValueRange &iterArgs, VectorType type, const llvm::ArrayRef &dims, llvm::SmallVector &inductionVars, const llvm::DenseMap &operandIdxMap); void moveOperationsToCurrentForBody( - const size_t groupIdx, OpBuilder &b, - const llvm::SmallVector &inductionVars, + const size_t groupIdx, const OpBuilder &b, ArrayRef inductionVars, const llvm::DenseMap &operandIdxMap, const ValueRange &loopState, std::queue &queue); + void getResultInCurrentOps(const size_t anchorIdx, const size_t groupId, + const std::queue ops, + SmallVector &results, + DenseMap &forResultOrignalResultMap); + void + getInitArgsToNextAnchor(const size_t anchorIdx, const size_t groupId, + const std::queue &nextOperations, + const ValueRange 
&loopState, + llvm::DenseMap ¤tLoopStateIdxMap, + llvm::DenseMap &nextAnchorArgsIdxMap, + llvm::SmallVector &nextAnchorArgs, + DenseMap &originalOperandLoopArgsMap); + + void getOperationInCurrentAnchor(const size_t anchorIdx, + std::queue &fromQueue, + std::queue &toQueue); + void generateLoopResults(OpBuilder &b, const Location &loc, + const size_t anchorIdx, const size_t groupIdx, + llvm::SmallVector &nextAnchorResults, + llvm::DenseMap &nextAnchorResultsIdxMap, + const ValueRange &forResults, + const std::queue &movedOperaiton, + DenseMap &forResultOrignalResultMap); + + void movePostOpToCurrentAnchor( + OpBuilder &b, const int anchorIdx, const int groupIdx, + const ValueRange &forResults, const Block *forBlock, + std::queue &candidateOps, + std::queue &movedOperation, ArrayRef inductionVars, + const llvm::DenseMap &operandIdxMap, + const ValueRange &loopState, + const llvm::SmallVector &nextAnchorResults, + DenseMap &forResultOrignalResultMap); + + void + movePreOpToCurrentAnchor(const size_t anchorIdx, const size_t groupIdx, + OpBuilder &b, ArrayRef inductionVars, + const ValueRange &loopState, + llvm::DenseMap ¤tLoopStateIdxMap, + llvm::DenseMap &nextLoopStateIdxMap, + llvm::SmallVector &nextAnchorArgs, + std::queue &candidateQueue, + std::queue &movedQueue, + DenseMap &originalOperandLoopArgsMap); + + void replaceOperationsWithForLoopResult( + IRRewriter &rewrite, const ValueRange &forResults, const Block *forBlock, + const llvm::SmallVector &nextAnchorResults, + const std::queue movingOperations, + DenseMap &forResultOrignalResultMap); // multireduction forloop methods scf::ForOp generateMultiReductionForLoop(const size_t grpIdx); - scf::ForOp - parallelAxisGenerateForLoop(OpBuilder &opBuilder, const int groupIdx, - const size_t parallelIdx, ValueRange &initArgs, - llvm::SmallVector &inductionVars, - Value &originalWriteResult); + scf::ForOp reductionAxisGenerateForLoop( + OpBuilder &opBuilder, const int groupIdx, const size_t reductionIdx, + const int anchorIdx, llvm::DenseMap ¤tLoopStateIdxMap, + const ValueRange &initArgs, + llvm::SmallVector &nextAnchorResults, + llvm::DenseMap &nextAnchorResultsIdxMap, + llvm::SmallVector &inductionVars, + DenseMap &forResultOrignalResultMap); + + scf::ForOp parallelAxisGenerateForLoop( + OpBuilder &opBuilder, const int groupIdx, const size_t parallelIdx, + llvm::DenseMap ¤tLoopStateIdxMap, + const ValueRange &initArgs, + llvm::SmallVector &nextAnchorResults, + llvm::DenseMap &nextAnchorResultsIdxMap, + llvm::SmallVector &inductionVars, + DenseMap &originalOperandLoopArgsMap, + DenseMap &forResultOrignalResultMap); scf::ForOp reductionAxisGenerateForLoop(OpBuilder &opBuilder, const int groupIdx, @@ -360,20 +502,34 @@ class VectorOperationAnalysizer : virtual public CanonicalizerCommonUsedData { public: virtual ~VectorOperationAnalysizer(){}; + VectorOperationAnalysizer() {} + VectorOperationAnalysizer(func::FuncOp &func) : func(func) {} void generateEmptyTensorAndWrite( - Operation *sourceOp, llvm::DenseMap> - &srcOpCanoniclizedMap); + Operation *sourceOp, + llvm::DenseMap> + &srcOpCanoniclizedMap, + size_t anchorPos, ReturnTypeKind retKind); void setAnalysisFunc(func::FuncOp &func) { this->func = func; } void analysisEmptyGroupAndMaxSteps(); void analysisGroupOperaion(); void analysisGroupOperationResults(); + Value + canonicalizeCurrentOperation(Operation *op, const Value &transferReadOperand, + size_t operandIdx, + vector::TransferReadOp *srcReadOp = nullptr); + void updateOpOperandResultInGroups(size_t opGid, Operation *op, Value &init, + 
const Value &result = Value()); + Operation * + createTransferReadOpBefore(Operation *op, const Value &operand, + vector::TransferReadOp *srcReadOp = nullptr); }; - -class CanonicalizerVectorOperation : virtual public VectorOperationAnalysizer, - ForLoopGenerator { +/// Vectorize vector operation with target machines simd instructions. +class CanonicalizerVectorOperation : virtual public ForLoopGenerator, + VectorOperationAnalysizer { private: func::FuncOp func; IRRewriter rewriter; + CanonicalizerKind kind; public: @@ -387,9 +543,9 @@ class CanonicalizerVectorOperation : virtual public VectorOperationAnalysizer, setHardWareInfo(hwInfo); // vector operation fusion if (kind == CanonicalizerKind::OperationsGroup) { - auto fusionStrategy = VectorFusionStrategy(func); + VectorFusionStrategy fusionStrategy(func); fusionStrategy.run(); - setFuseStrategy(fusionStrategy); + setFuseStrategy(std::move(fusionStrategy)); } } virtual ~CanonicalizerVectorOperation(){}; @@ -399,7 +555,6 @@ class CanonicalizerVectorOperation : virtual public VectorOperationAnalysizer, IRRewriter &getIRWewriter() { return rewriter; } // void canonicalizeSpecialOperation(); - LogicalResult canonicalizeReductionOperation(); void clearSpecialOperationCanonicalizers(); void dummyInitSpecialOperation(); void initSpeicalOperationCanonicalizers(); diff --git a/lib/gc/Transforms/CPUPhysicalResigterPass.cpp b/lib/gc/Transforms/CPUPhysicalResigterPass.cpp index 64044d406..b515b6593 100644 --- a/lib/gc/Transforms/CPUPhysicalResigterPass.cpp +++ b/lib/gc/Transforms/CPUPhysicalResigterPass.cpp @@ -25,7 +25,7 @@ namespace { // TODO: remove it in the future bool disableSpecialOp = false; -void printGroupOps(llvm::SmallVector, 8> &opGroups) { +void printGroupOps(SmallVector, 8> &opGroups) { for (auto grp : opGroups) { if (grp.empty()) { continue; @@ -42,13 +42,25 @@ void printGroupOps(llvm::SmallVector, 8> &opGroups) { } } +void printQueue(const std::queue &opQueue) { + std::cout << "________________________________ op Queue " + "__________________" + << std::endl; + auto tempQ(opQueue); + while (!tempQ.empty()) { + auto cur = tempQ.front(); + cur->dump(); + tempQ.pop(); + } + std::cout << "________________________________ op queue end " + "__________________" + << std::endl; +} + bool isSpecialOp(Operation *op) { - return llvm::isa(op) || - llvm::isa(op) || - llvm::isa(op) || - llvm::isa(op) || - llvm::isa(op) || - llvm::isa(op); + return isa(op) || isa(op) || + isa(op) || isa(op) || + isa(op) || isa(op); } bool is_innermost_operation(Operation *op) { @@ -64,16 +76,13 @@ bool is_innermost_operation(Operation *op) { } bool isNotSupportOperation(Operation *op) { - return llvm::isa(op) || - llvm::isa(op) || - llvm::isa(op) || - llvm::isa(op) || - llvm::isa(op); + return isa(op) || isa(op) || + isa(op) || isa(op) || + isa(op); } bool isReadOrWriteOperation(Operation *op) { - return llvm::isa(op) || - llvm::isa(op); + return isa(op) || isa(op); } // TODO: Need to support these operations in the future @@ -88,7 +97,7 @@ bool hasNotSupportOperation(func::FuncOp *func) { } // select nearest even step -int get_nearest_vector_step(const int step) { +int getNearestVectorStep(const int step) { assert(step > 0); int nbits = 0, n = step; while (n) { @@ -102,7 +111,7 @@ int get_nearest_vector_step(const int step) { int TypeHelper::generateValidSteps(int steps, VectorType type) { return type.getShape().back() >= steps ? (steps > 16 ? 
16 : steps) - : get_nearest_vector_step(type.getShape().back()); + : getNearestVectorStep(type.getShape().back()); } // expr equals `vector rank` - 1 @@ -111,7 +120,7 @@ bool isLastDim(const AffineExpr &expr, const size_t rank) { mlir::dyn_cast(expr).getPosition() == rank - 1; } -[[nodiscard]] int TypeHelper::getDataTypeValidSteps(const VectorType &type) { +[[nodiscard]] int TypeHelper::getDataTypeValidSteps(VectorType type) { auto typebits = type.getElementTypeBitWidth(); const int favx512bits = 512; const int favx2bits = 256; @@ -128,7 +137,7 @@ bool isLastDim(const AffineExpr &expr, const size_t rank) { } // Get the maximum number of current data types that a register can hold -[[nodiscard]] int TypeHelper::getDataTypeMAXSIMDLength(const VectorType &type) { +[[nodiscard]] int TypeHelper::getDataTypeMAXSIMDLength(VectorType type) { auto typebits = type.getElementTypeBitWidth(); const int favx512bits = 512; const int favx2bits = 256; @@ -207,6 +216,15 @@ VectorType TypeHelper::getVectorzedType(Operation *op, uint32_t loop_step) { return VectorType::get({loop_step}, vectorizedType.getElementType()); } +/// whether the operation result need to be returned +/// \param anchorIdx resuilt produce operation anchor position +/// \param retType resuilt return type +bool needReturnResult(std::pair &retType, + size_t anchorIdx) { + return !(retType.first == ReturnTypeKind::RT_InGroup and + retType.second >= anchorIdx); +} + union Float32Bits { uint32_t u; float f; @@ -219,8 +237,8 @@ const Float32Bits kF32Magic = {113 << kF32MantiBits}; const uint32_t kF32HalfExpAdjust = (127 - 15) << kF32MantiBits; const uint32_t kF32BfMantiBitDiff = 16; -// Constructs the 16 bit representation for a half precision value from a float -// value. This implementation is adapted from Eigen. +/// Constructs the 16 bit representation for a half precision value from a float +/// value. This implementation is adapted from Eigen. uint16_t float2half(float floatValue) { const Float32Bits inf = {255 << kF32MantiBits}; const Float32Bits f16max = {(127 + 16) << kF32MantiBits}; @@ -266,8 +284,8 @@ uint16_t float2half(float floatValue) { return halfValue; } -// Converts the 16 bit representation of a half precision value to a float -// value. This implementation is adapted from Eigen. +/// Converts the 16 bit representation of a half precision value to a float +/// value. This implementation is adapted from Eigen. float half2float(uint16_t halfValue) { const uint32_t shiftedExp = 0x7c00 << kF32HalfMantiBitDiff; // Exponent mask after shift. 
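+  // (0x7c00 covers the five fp16 exponent bits, which start at bit 10; the
+  //  shift by kF32HalfMantiBitDiff, the f32/f16 mantissa-width difference,
+  //  lines the mask up with the f32 exponent field.)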
@@ -440,19 +458,19 @@ T getInitValForReduce(vector::CombiningKind kind, Type t) { case vector::CombiningKind::ADD: if (t1.isIntOrIndex()) result = 0; - else if (llvm::isa(t1)) + else if (isa(t1)) result = 0.0f; else llvm_unreachable("invalid value types for ADD reduction"); break; case vector::CombiningKind::MAXNUMF: case vector::CombiningKind::MAXIMUMF: - assert(llvm::isa(t1) && "expected float values"); + assert(isa(t1) && "expected float values"); result = std::get(numeric_limits_minimum(t)); break; case vector::CombiningKind::MINNUMF: case vector::CombiningKind::MINIMUMF: - assert(llvm::isa(t1) && "expected float values"); + assert(isa(t1) && "expected float values"); result = std::get(numericLimitsMaximum(t)); break; case vector::CombiningKind::MAXSI: @@ -468,7 +486,7 @@ T getInitValForReduce(vector::CombiningKind kind, Type t) { case vector::CombiningKind::MUL: if (t1.isIntOrIndex()) result = 1; - else if (llvm::isa(t1)) + else if (isa(t1)) result = 1.f; else llvm_unreachable("invalid value types for MUL reduction"); @@ -506,12 +524,13 @@ void setOpVectorizationPermutationMap(Operation *op, OpBuilder &rewriter, } // scf.for yield helper function -void maybeYieldValue(OpBuilder &b, Location loc, ValueRange value) { +scf::YieldOp maybeYieldValue(OpBuilder &b, Location loc, + const ValueRange &value) { bool hasRetVal = !value.empty(); if (hasRetVal) { - b.create(loc, value); + return b.create(loc, value); } else { - b.create(loc); + return b.create(loc); } } @@ -586,10 +605,9 @@ Operation *createTransferWriteOpAfter(Operation *op, const Value &dest) { /*inBounds=*/inBoundsVal); } -Operation * -createTransferReadOpBefore(Operation *op, const Value &operand, - vector::TransferReadOp *srcReadOp = nullptr) { - auto operandType = mlir::dyn_cast(operand.getType()); +Operation *VectorOperationAnalysizer::createTransferReadOpBefore( + Operation *op, const Value &operand, vector::TransferReadOp *srcReadOp) { + auto operandType = cast(operand.getType()); IRRewriter rewriter(op); auto zero = @@ -599,7 +617,7 @@ createTransferReadOpBefore(Operation *op, const Value &operand, rewriter.getZeroAttr(operandType.getElementType())); if (srcReadOp) { - auto resultType = mlir::dyn_cast(srcReadOp->getType()); + auto resultType = cast(srcReadOp->getType()); SmallVector inBoundsVal(resultType.getRank(), true); auto srcReadOpAffineMap = srcReadOp->getPermutationMap(); // result of read operation should be same as operand @@ -611,6 +629,9 @@ createTransferReadOpBefore(Operation *op, const Value &operand, /*indices=*/SmallVector(operandType.getRank(), zero), /**affinemap*/ srcReadOpAffineMap, /*inBounds=*/inBoundsVal); + DenseMap &permutationMap = getOpPermuationMap(); + permutationMap[t] = srcReadOpAffineMap; + getFusionStrategy().getOpAnchorPos()[t] = t.getVectorType().getRank() - 1; return t; } else { @@ -623,6 +644,10 @@ createTransferReadOpBefore(Operation *op, const Value &operand, /*indices=*/SmallVector(operandType.getRank(), zero), /**affinemap*/ padValue, /*inBounds=*/inBoundsVal); + DenseMap &permutationMap = getOpPermuationMap(); + permutationMap[t] = t.getPermutationMap(); + getFusionStrategy().getOpAnchorPos()[t] = t.getVectorType().getRank() - 1; + return t; } } @@ -631,18 +656,17 @@ createTransferReadOpBefore(Operation *op, const Value &operand, // result into the empty tensor [[nodiscard]] std::pair canonicalizeSourceOperation(Operation *op) { - // auto emtpyOp = createTensorEmptyBefore(op); auto resultTensor = getOperationResultTensor(op); auto writeOp = createTransferWriteOpAfter(op, 
resultTensor); return std::make_pair(resultTensor, writeOp->getResults()[0]); } -[[nodiscard]] Value -canonicalizeCurrentOperation(Operation *op, const Value &transferReadOperand, - size_t operandIdx, - vector::TransferReadOp *srcReadOp = nullptr) { +[[nodiscard]] Value VectorOperationAnalysizer::canonicalizeCurrentOperation( + Operation *op, const Value &transferReadOperand, size_t operandIdx, + vector::TransferReadOp *srcReadOp) { // transfer_read operation auto readOp = createTransferReadOpBefore(op, transferReadOperand, srcReadOp); + op->setOperand(operandIdx, readOp->getResults()[0]); return readOp->getResults()[0]; } @@ -655,10 +679,10 @@ canonicalizeCurrentOperation(Operation *op, const Value &transferReadOperand, // MultiReduce Operation //===----------------------------------------------------------------------===// -void getOpSourceOps(Operation *op, llvm::DenseSet &srcOps) { - llvm::SmallVector srcOperands = op->getOperands(); +void getOpSourceOps(Operation *op, DenseSet &srcOps) { + SmallVector srcOperands = op->getOperands(); std::deque srcOperandsQueue(srcOperands.begin(), srcOperands.end()); - llvm::DenseSet visited; + DenseSet visited; visited.insert(op); while (!srcOperandsQueue.empty()) { auto accOperand = srcOperandsQueue.front(); @@ -675,7 +699,7 @@ void getOpSourceOps(Operation *op, llvm::DenseSet &srcOps) { } } -bool isSrcRelated(const llvm::DenseSet &srcOps, Operation *op) { +bool isSrcRelated(const DenseSet &srcOps, Operation *op) { return srcOps.count(op); } @@ -718,7 +742,7 @@ void classifySourceRelatedOps(std::queue &accRelatedOps, std::queue &sourceRelatedOps, Operation *srcOp, std::queue &prevOps) { - llvm::DenseSet srcOps; + DenseSet srcOps; getOpSourceOps(srcOp, srcOps); while (!prevOps.empty()) { auto op = prevOps.front(); @@ -732,8 +756,8 @@ void classifySourceRelatedOps(std::queue &accRelatedOps, } void updateReduceReadWriteOperationOperand( - const llvm::SmallVector &inductionVars, - const llvm::SmallVector ¶llelAxis, Operation *op, + const SmallVector &inductionVars, + const SmallVector ¶llelAxis, Operation *op, MultiReduceOpAxisKind rdKind = MultiReduceOpAxisKind::Parallel) { int indiceOffset = mlir::isa(op) ? 
1 : 2; for (auto [idx, inductionVar] : llvm::enumerate(inductionVars)) { @@ -747,8 +771,8 @@ void updateReduceReadWriteOperationOperand( vector::TransferReadOp ForLoopGenerator::cloneReductionTransferRead( Value &source, OpBuilder &b, IRMapping &readMap, - const llvm::SmallVector ¶llelAxis, - llvm::SmallVector &inductionVars, bool lastDimReduction, + const SmallVector ¶llelAxis, + SmallVector &inductionVars, bool lastDimReduction, MultiReduceOpAxisKind rdKind) { IRRewriter rewriter(b); auto readOp = mlir::dyn_cast(source.getDefiningOp()); @@ -777,8 +801,8 @@ vector::TransferReadOp ForLoopGenerator::cloneReductionTransferRead( vector::TransferWriteOp makeNewTransferWriteOp(Value source, IRMapping &writeMap, OpBuilder &b, - const llvm::SmallVector ¶llelAxis, - llvm::SmallVector &inductionVars) { + const SmallVector ¶llelAxis, + SmallVector &inductionVars) { IRRewriter bodyRewriter(b); auto writeOp = source.getDefiningOp(); auto newWriteOp = @@ -801,14 +825,14 @@ Value makeIndexArithConstantOp(OpBuilder &opBuilder, const Location &loc, } void ForLoopGenerator::moveOperationsToCurrentForBody( - const size_t groupIdx, OpBuilder &b, - const llvm::SmallVector &inductionVars, - const llvm::DenseMap &operandIdxMap, - const ValueRange &loopState, std::queue &opQueue) { + const size_t groupIdx, const OpBuilder &b, ArrayRef inductionVars, + const DenseMap &operandIdxMap, const ValueRange &loopState, + std::queue &opQueue) { auto &opPermuationMap = getOpPermuationMap(); - while (!opQueue.empty()) { - auto x = opQueue.front(); - opQueue.pop(); + auto tmpQ(opQueue); + while (!tmpQ.empty()) { + auto x = tmpQ.front(); + tmpQ.pop(); x->moveBefore(b.getBlock(), b.getBlock()->end()); // check operation type to set correct operand setOperationCorrectOperand(x, loopState, operandIdxMap, inductionVars, @@ -833,94 +857,318 @@ bool hasOtherOperations(const std::queue &opQ, return res; }; +void ForLoopGenerator::getResultInCurrentOps( + const size_t anchorIdx, const size_t groupId, + const std::queue ops, SmallVector &results, + DenseMap &forResultOrignalResultMap) { + auto tmpQ(ops); + llvm::MapVector> &groupResults = + getGroupOpResults()[groupId]; + while (!tmpQ.empty()) { + Operation *cur = tmpQ.front(); + tmpQ.pop(); + auto curResult = cur->getResults()[0]; + if (groupResults.contains(curResult)) { + std::pair retType = groupResults[curResult]; + if (needReturnResult(retType, anchorIdx)) { + results.emplace_back(curResult); + forResultOrignalResultMap[curResult] = curResult; + } + } + } +} +void ForLoopGenerator::getInitArgsToNextAnchor( + const size_t anchorIdx, const size_t groupId, + const std::queue &nextOperations, const ValueRange &loopState, + DenseMap ¤tLoopStateIdxMap, + DenseMap &nextAnchorArgsIdxMap, + SmallVector &nextAnchorArgs, + DenseMap &originalOperandLoopArgsMap) { + DenseMap &opAnchorPos = + getFusionStrategy().getOpAnchorPos(); + SetVector &opInitArgs = getGroupOpInitArgs()[groupId]; + DenseSet visited; + // find the next anchor arguments + std::queue tmpQ(nextOperations); + DenseMap nextOperandArgsMap; + while (!tmpQ.empty()) { + Operation *cur = tmpQ.front(); + tmpQ.pop(); + auto curOperands = cur->getOperands(); + for (auto x : curOperands) { + if (!visited.contains(x) and opInitArgs.contains(x) and + opAnchorPos[cur] > anchorIdx) { + int loopStateIdx = + currentLoopStateIdxMap[originalOperandLoopArgsMap[x]]; + nextAnchorArgs.emplace_back(loopState[loopStateIdx]); + nextOperandArgsMap[x] = loopState[loopStateIdx]; + nextAnchorArgsIdxMap[loopState[loopStateIdx]] = + nextAnchorArgs.size() 
- 1; + visited.insert(x); + } + } + } + originalOperandLoopArgsMap = nextOperandArgsMap; +} + +void ForLoopGenerator::getOperationInCurrentAnchor( + const size_t anchorIdx, std::queue &fromQueue, + std::queue &toQueue) { + while (!fromQueue.empty()) { + Operation *curOp = fromQueue.front(); + if (anchorIdx == getFusionStrategy().getOpAnchorPos()[curOp]) { + toQueue.push(curOp); + fromQueue.pop(); + continue; + } + break; + } +} + +void ForLoopGenerator::replaceOperationsWithForLoopResult( + IRRewriter &rewrite, const ValueRange &forResults, const Block *forBlock, + const llvm::SmallVector &nextAnchorResults, + const std::queue movingOperations, + DenseMap &forResultOrignalResultMap) { + auto tmpQ(movingOperations); + DenseSet operationOperands; + while (!tmpQ.empty()) { + auto curOp = tmpQ.front(); + tmpQ.pop(); + for (auto x : curOp->getOperands()) { + operationOperands.insert(x); + } + } + auto replaceIfFn = [&](OpOperand &use) { + return operationOperands.contains(use.get()); + }; + for (auto [nxtForResult, nextLoopResult] : + zip(forResults, nextAnchorResults)) { + Value originalResult = forResultOrignalResultMap[nextLoopResult]; + rewrite.replaceOpUsesWithIf(originalResult.getDefiningOp(), nxtForResult, + replaceIfFn); + } +} + +/// \param [out] nextLoopStateidxMap +/// \param [out] nextAnchorArgs +/// \param [out] movingQueue +void ForLoopGenerator::movePreOpToCurrentAnchor( + const size_t anchorIdx, const size_t groupIdx, OpBuilder &b, + ArrayRef inductionVars, const ValueRange &loopState, + DenseMap ¤tLoopStateIdxMap, + DenseMap &nextAnchorArgsIdxMap, + SmallVector &nextAnchorArgs, + std::queue &candidateQueue, + std::queue &movedQueue, + DenseMap &originalOperandLoopArgsMap) { + + // 1. get operations in current anchor position + std::queue movingOperation; + getOperationInCurrentAnchor(anchorIdx, candidateQueue, movingOperation); + + // 2. get next anchor args + getInitArgsToNextAnchor(anchorIdx, groupIdx, candidateQueue, loopState, + currentLoopStateIdxMap, nextAnchorArgsIdxMap, + nextAnchorArgs, originalOperandLoopArgsMap); + + // 3. rewrite operation as vectorize IR + rewriteOperationAsVectorize(b, groupIdx, &movingOperation); + + // 4. move opeartions to current for block + moveOperationsToCurrentForBody(groupIdx, b, inductionVars, + currentLoopStateIdxMap, loopState, + movingOperation); + + // 5. move operations to moved queue + while (!movingOperation.empty()) { + movedQueue.push(movingOperation.front()); + movingOperation.pop(); + } +} + +void ForLoopGenerator::movePostOpToCurrentAnchor( + OpBuilder &b, const int anchorIdx, const int groupIdx, + const ValueRange &forResults, const Block *forBlock, + std::queue &candidateOps, std::queue &movedOps, + ArrayRef inductionVars, const DenseMap &operandIdxMap, + const ValueRange &loopState, const SmallVector &nextAnchorResults, + DenseMap &forResultOrignalResultMap) { + + // 1. move post-op to current loop body + std::queue movingOperations; + getOperationInCurrentAnchor(anchorIdx, candidateOps, movingOperations); + + rewriteOperationAsVectorize(b, groupIdx, &movingOperations); + + moveOperationsToCurrentForBody(anchorIdx, b, inductionVars, operandIdxMap, + loopState, movingOperations); + + // 2. replace correct for loop result to post-op + IRRewriter rewriter(b); + replaceOperationsWithForLoopResult(rewriter, forResults, forBlock, + nextAnchorResults, movingOperations, + forResultOrignalResultMap); + + // 3. 
move operations to moved queue + while (!movingOperations.empty()) { + movedOps.push(movingOperations.front()); + movingOperations.pop(); + } +} + +void ForLoopGenerator::generateLoopResults( + OpBuilder &b, const Location &loc, const size_t anchorIdx, + const size_t groupIdx, SmallVector &nextAnchorResults, + DenseMap &nextAnchorResultsIdxMap, const ValueRange &forResults, + const std::queue &movedOperation, + DenseMap &forResultOrignalResultMap) { + SmallVector results; + DenseMap currentResultMap; + getResultInCurrentOps(anchorIdx, groupIdx, movedOperation, results, + currentResultMap); + + llvm::MapVector> &groupResults = + getGroupOpResults()[groupIdx]; + // check for yield results whether need to return to next anchor + for (auto [idx, forResult] : llvm::enumerate(nextAnchorResults)) { + Value originalResult = forResultOrignalResultMap[forResult]; + + if (groupResults.contains(originalResult)) { + std::pair resultType = + groupResults[originalResult]; + if (needReturnResult(resultType, anchorIdx)) { + results.emplace_back(forResults[idx]); + currentResultMap[forResults[idx]] = originalResult; + } + } + } + + nextAnchorResults.clear(); + nextAnchorResultsIdxMap.clear(); + for (Value &result : results) { + nextAnchorResults.emplace_back(result); + nextAnchorResultsIdxMap[result] = nextAnchorResults.size() - 1; + } + forResultOrignalResultMap = std::move(currentResultMap); + + maybeYieldValue(b, loc, nextAnchorResults); +} + scf::ForOp ForLoopGenerator::reductionAxisGenerateForLoop( OpBuilder &opBuilder, const int groupIdx, const size_t reductionIdx, - ValueRange &initArgs, llvm::SmallVector &inductionVars) { + const int anchorIdx, DenseMap ¤tLoopStateIdxMap, + const ValueRange &initArgs, SmallVector &nextAnchorResults, + DenseMap &nextAnchorResultsIdxMap, + SmallVector &inductionVars, + DenseMap &forResultOrignalResultMap) { + MultiReductionCanonicalizer rdCanonicalizer = getMultiRdCanonicalizers()[groupIdx]; auto &multireductionOp = rdCanonicalizer.getCandidateOps()[0]; + VectorFusionStrategy &fusionStrategy = getFusionStrategy(); + + SmallVector, 8> &opGroups = + fusionStrategy.getOpGroups(); + std::queue &opQueue = opGroups[groupIdx]; + const auto loc = multireductionOp->getLoc(); - auto &reductionAxis = rdCanonicalizer.getReductionAxis(); - auto lastDimReduction = rdCanonicalizer.hasLastDimReduction(); - auto vectorType = rdCanonicalizer.getSourceType(); + SmallVector &reductionAxis = rdCanonicalizer.getReductionAxis(); + bool lastDimReduction = rdCanonicalizer.hasLastDimReduction(); + VectorType vectorType = rdCanonicalizer.getSourceType(); const int loopStep = getDataTypeValidSteps(vectorType); - auto isStandaloneOp = rdCanonicalizer.getIsStandAloneOp(); - auto zero = makeIndexArithConstantOp(opBuilder, loc, 0); - auto forSteps = makeIndexArithConstantOp( + IRRewriter rewriterOfFunc(func); + + Value zero = makeIndexArithConstantOp(opBuilder, loc, 0); + Value forSteps = makeIndexArithConstantOp( opBuilder, loc, (reductionIdx == reductionAxis.size() - 1 && lastDimReduction) ? 
loopStep : 1); - auto numIter = makeIndexArithConstantOp( + Value numIter = makeIndexArithConstantOp( opBuilder, loc, vectorType.getShape()[reductionAxis[reductionIdx]]); - auto forOp = opBuilder.create( + scf::ForOp forOp = opBuilder.create( loc, zero, numIter, forSteps, initArgs, [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) { inductionVars.emplace_back(iv); - if (reductionIdx == reductionAxis.size() - 1) { - - if (isStandaloneOp) { - IRRewriter rewriter(b); - IRMapping readMap; - Value reductionTarget = multireductionOp.getSource(); - llvm::SmallVector parallelAxis; - auto newReadOp = cloneReductionTransferRead( - reductionTarget, b, readMap, parallelAxis, inductionVars, - lastDimReduction, MultiReduceOpAxisKind::Reduction); - auto reductionResult = - makeArithReduction(b, loc, multireductionOp.getKind(), - newReadOp->getResult(0), loopState.back()); - maybeYieldValue(b, loc, reductionResult); - } else { - auto &analysisResults = getGroupOpResults()[groupIdx]; - - auto &sourceOps = - getMultiRdCanonicalizers()[groupIdx].getSourceRelatedOps(); - auto &grpArgs = getGroupOpIterArgs()[groupIdx]; - - rewriteOperationAsVectorize(b, groupIdx, sourceOps); - llvm::DenseMap operandIdxMap; - llvm::SmallVector resultArray; - // dummy - resultArray.emplace_back(Value()); - std::queue tmpSourceOps(sourceOps); - // move operation into current for loop body - // accVal is first loopstate - int start = 1; - while (!tmpSourceOps.empty()) { - auto cur = tmpSourceOps.front(); - tmpSourceOps.pop(); - auto curOperands = cur->getOperands(); - for (auto x : curOperands) { - if (grpArgs.contains(x)) { - operandIdxMap[x] = start++; - } - } - if (analysisResults.contains(cur->getResults()[0])) { - resultArray.emplace_back(cur->getResults()[0]); - getMultiRdCanonicalizers()[groupIdx] - .getOriginalOpResults() - .insert(cur->getResults()[0]); - getMultiRdCanonicalizers()[groupIdx].getResultIdxMap().insert( - {cur->getResults()[0], resultArray.size() - 1}); - } + if (reductionIdx < reductionAxis.size() - 1) { + + // 1. move pre-Op to current body + DenseMap nextAnchorArgsIdxMap; + SmallVector nextAnchorArgs; + std::queue movedOperation; + DenseMap operandArgsMap; + movePreOpToCurrentAnchor(anchorIdx, groupIdx, b, inductionVars, + loopState, currentLoopStateIdxMap, + nextAnchorArgsIdxMap, nextAnchorArgs, + opQueue, movedOperation, operandArgsMap); + + // 2. generate next for loop + scf::ForOp nxtFor = reductionAxisGenerateForLoop( + b, groupIdx, reductionIdx + 1, anchorIdx + 1, + nextAnchorArgsIdxMap, nextAnchorArgs, nextAnchorResults, + nextAnchorResultsIdxMap, inductionVars, + forResultOrignalResultMap); + + // 3. move postOp to current body + movePostOpToCurrentAnchor( + rewriterOfFunc, anchorIdx, groupIdx, nxtFor->getResults(), + nxtFor->getBlock(), opQueue, movedOperation, inductionVars, + currentLoopStateIdxMap, loopState, nextAnchorResults, + forResultOrignalResultMap); + + // 4. rewrite operations as vectorized IR + rewriteOperationAsVectorize(b, groupIdx, &movedOperation); + + // 5. 
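// The recursive case above emits one scf.for per reduction axis: pre-ops for
// this depth are hoisted, the next axis recurses, post-ops and yields follow,
// and only the innermost reduction axis advances by the vector length. A
// standalone sketch of the same recursion over a plain shape, printing loop
// headers instead of building IR; the emitReductionLoops helper and the example
// shape in main() are illustrative.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

void emitReductionLoops(const std::vector<int64_t> &reductionDimSizes,
                        std::size_t idx, int64_t simdStep, bool lastDimReduction,
                        int indent = 0) {
  if (idx == reductionDimSizes.size())
    return;
  bool innermost = idx + 1 == reductionDimSizes.size();
  // Mirror the forSteps choice above: SIMD step only on the innermost axis when
  // the last dimension is reduced, otherwise step 1.
  int64_t step = (innermost && lastDimReduction) ? simdStep : 1;
  std::cout << std::string(indent, ' ') << "for i" << idx << " = 0 to "
            << reductionDimSizes[idx] << " step " << step << "\n";
  emitReductionLoops(reductionDimSizes, idx + 1, simdStep, lastDimReduction,
                     indent + 2);
}

int main() {
  // e.g. reducing a 16x64 source with 16 f32 lanes per register
  emitReductionLoops({16, 64}, 0, /*simdStep=*/16, /*lastDimReduction=*/true);
  return 0;
}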
generate loop results + generateLoopResults(b, loc, anchorIdx, groupIdx, nextAnchorResults, + nextAnchorResultsIdxMap, nxtFor->getResults(), + movedOperation, forResultOrignalResultMap); + + } else if (reductionIdx == reductionAxis.size() - 1) { + std::queue movingOperation; + + while (!opQueue.empty()) { + Operation *curOp = opQueue.front(); + opQueue.pop(); + if (isa(curOp)) { + break; } - moveOperationsToCurrentForBody(groupIdx, b, inductionVars, - operandIdxMap, loopState, sourceOps); - - auto reductionResult = makeArithReduction( - b, loc, multireductionOp.getKind(), - multireductionOp.getSource(), loopState.back()); - resultArray[0] = reductionResult; - - maybeYieldValue(b, loc, resultArray); + movingOperation.push(curOp); } - } else { - // outter loop - auto nxtFor = reductionAxisGenerateForLoop( - b, groupIdx, reductionIdx + 1, loopState, inductionVars); - maybeYieldValue(b, loc, nxtFor->getResults()); + while (!opQueue.empty()) { + Operation *curOp = opQueue.front(); + if (isa(curOp)) { + opQueue.pop(); + continue; + } + break; + } + rewriteOperationAsVectorize(b, groupIdx, &movingOperation); + + moveOperationsToCurrentForBody(groupIdx, b, inductionVars, + currentLoopStateIdxMap, loopState, + movingOperation); + + auto reductionResult = makeArithReduction( + b, loc, multireductionOp.getKind(), multireductionOp.getSource(), + loopState.back()); + + movePostOpToCurrentAnchor( + rewriterOfFunc, anchorIdx, groupIdx, ValueRange(), + reductionResult.getParentBlock(), opQueue, movingOperation, + inductionVars, currentLoopStateIdxMap, loopState, + nextAnchorResults, forResultOrignalResultMap); + + nextAnchorResults.clear(); + nextAnchorResults.emplace_back(reductionResult); + nextAnchorResultsIdxMap[reductionResult] = 0; + forResultOrignalResultMap[reductionResult] = + multireductionOp->getResults()[0]; + getResultInCurrentOps(anchorIdx, groupIdx, movingOperation, + nextAnchorResults, forResultOrignalResultMap); + maybeYieldValue(b, loc, nextAnchorResults); } }); @@ -931,347 +1179,282 @@ scf::ForOp ForLoopGenerator::reductionAxisGenerateForLoop( // This function also call reduction axis for loop scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( OpBuilder &opBuilder, const int groupIdx, const size_t parallelIdx, - ValueRange &initArgs, llvm::SmallVector &inductionVars, - Value &originalWriteResult) { + DenseMap ¤tLoopStateIdxMap, const ValueRange &initArgs, + SmallVector &nextAnchorResults, + DenseMap &nextAnchorResultsIdxMap, + SmallVector &inductionVars, + DenseMap &originalOperandLoopArgsMap, + DenseMap &forResultOrignalResultMap) { auto &rdCanonicalizer = getMultiRdCanonicalizers()[groupIdx]; - auto &multiReductionOp = rdCanonicalizer.getCandidateOps()[0]; - auto &vectorType = rdCanonicalizer.getSourceType(); - auto &accType = rdCanonicalizer.getAccType(); + vector::MultiDimReductionOp &multiReductionOp = + rdCanonicalizer.getCandidateOps()[0]; + VectorType vectorType = rdCanonicalizer.getSourceType(); IRRewriter rewriterOfFunc(func); - auto ¶llelAxis = rdCanonicalizer.getParallelAxis(); - auto isStandaloneOp = rdCanonicalizer.getIsStandAloneOp(); - auto lastDimReduction = rdCanonicalizer.hasLastDimReduction(); - const auto &loc = multiReductionOp.getLoc(); - auto zero = makeIndexArithConstantOp(opBuilder, loc, 0); - auto forSteps = makeIndexArithConstantOp(opBuilder, loc, 1); + SmallVector ¶llelAxis = rdCanonicalizer.getParallelAxis(); + const Location &loc = multiReductionOp.getLoc(); + Value zero = makeIndexArithConstantOp(opBuilder, loc, 0); + Value forSteps = 
makeIndexArithConstantOp(opBuilder, loc, 1); - // last dim reduction need to a generate dim=16 loop + // last dim reduction need to a generate dim=16 loop for fused with pre-op int dimSize = 0; if (parallelIdx == parallelAxis.size()) { - // TODO: need to consider data type lanes - dimSize = 16; + dimSize = getDataTypeMAXSIMDLength(vectorType); } else { dimSize = vectorType.getShape()[parallelAxis[parallelIdx]]; } - auto numIter = makeIndexArithConstantOp(opBuilder, loc, dimSize); + Value numIter = makeIndexArithConstantOp(opBuilder, loc, dimSize); // Create a loop and move vectorized operation into loops. - auto forOp = opBuilder.create( + return opBuilder.create( loc, zero, numIter, forSteps, initArgs, [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) { inductionVars.emplace_back(iv); - auto &fusionStrategy = getFusionStrategy(); - auto &opIndexMap = fusionStrategy.getOpGroupIndexMap(); + VectorFusionStrategy &fusionStrategy = getFusionStrategy(); + DenseMap &opIndexMap = + fusionStrategy.getOpGroupIndexMap(); assert(opIndexMap.contains(multiReductionOp) && " Must constains multireduction operation."); - auto opIndex = opIndexMap[multiReductionOp]; - auto &opGroups = fusionStrategy.getOpGroups(); - auto opQueue = opGroups[opIndex]; - auto multiReductionAcc = multiReductionOp.getAcc(); - - if (parallelIdx == parallelAxis.size() - 1) { - // four kinds of group operations - // If fused a operation, it means multirection must just - // constains last dim to do the reduction. - // 1. just multireduction - // two cases: - // 1. constaints last dims - // for ... parallel axis: - // transfer_read from accSource tensor - // arith.constant : vector<16xf32> - // for ... 16: - // for ... reduction axis: - // add - // scalar = reduction vector - // scalar insert into accVector - // transfer_write accVector into emtpy tensor - // 2. not last dims - // for ... generate axis: - // transfer_read from accSource tensor - // transfer_read from source tensor - // accVector = add - // transfer_write accVector into emtpy tensor - // 2. prev-op + multireduction - // In this case, there will be no tensor.empty + transfer_read - // operation, but the multireduction should write in an empty - // tensor - // for ... parallel axis: - // accVector and related accVector operation should be here - // extract from accVector scalar - // airth.constant : vector<16xf32> - // for ... reduction axis: - // prevop source op - // add - // scalar = reduction vector - // scalar insert into accVector - // transfer_write accVector into empty tensor - // - // 3. post-op + multireduction - // for ... parallel axis: - // transferread from accSource tensor - // arith.constant : vector<16xf32> - // for ... reduction axis: - // add - // postOp - // post Op transferWrite emtpy tensor - // scalar = reduction vector - // scalar insert into accVector - // transfer_write accVector into emtpy tensor - // 4. prev-op + multireduction + post-op - // for ... parallel axis: - // accVector operation - // extract from accVector a scalar - // arith.constant : vector<16xf32> - // for ... 
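// The trip count chosen above is either the size of the current parallel axis
// or, for the extra loop one level past the last parallel axis, the number of
// register lanes. A minimal sketch of that selection, assuming the lane count
// is supplied by the caller; parallelTripCount and the shapes in main() are
// illustrative, not the patch's API.
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

int64_t parallelTripCount(std::size_t parallelIdx,
                          const std::vector<std::size_t> &parallelAxis,
                          const std::vector<int64_t> &vectorShape,
                          int64_t simdLanes) {
  assert(parallelIdx <= parallelAxis.size() && "loop depth out of range");
  if (parallelIdx == parallelAxis.size())
    return simdLanes; // the lane loop that lets pre-ops fuse with the reduction
  return vectorShape[parallelAxis[parallelIdx]];
}

int main() {
  std::vector<std::size_t> parallelAxis{0};
  std::vector<int64_t> shape{128, 64}; // e.g. vector<128x64xf32>, reducing axis 1
  std::cout << parallelTripCount(0, parallelAxis, shape, 16) << "\n"; // 128
  std::cout << parallelTripCount(1, parallelAxis, shape, 16) << "\n"; // 16 lanes
  return 0;
}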
reduction axis: - // prev-op source op and related source operation - // add - // postOp - // post Op transferWrite emtpy tensor - // scalar = reduction vector - // scalar insert into accVector - // transfer_write accVector into emtpy tensor - - if (isStandaloneOp) { - // read operation - IRMapping accReadMap; - auto accReadOp = multiReductionAcc.getDefiningOp(); - assert(mlir::isa(accReadOp)); - accReadMap.map(accReadOp->getOperand(0), loopState.back()); - - auto newAccReadOp = cloneReductionTransferRead( - multiReductionAcc, b, accReadMap, parallelAxis, inductionVars, - lastDimReduction, MultiReduceOpAxisKind::Parallel); - // constructe next for loop - Attribute initValueAttr; - getReductionInitAttr(multiReductionOp, initValueAttr); - - auto accVal = b.create( - loc, DenseElementsAttr::get(accType, {initValueAttr})); - - ValueRange newIterArgs(accVal); - - auto nxtFor = reductionAxisGenerateForLoop( - b, groupIdx, 0, newIterArgs, inductionVars); - - // insert accumulate value to original vector - auto accRes = nxtFor->getResults()[0]; - - Operation *reductionOp = b.create( - loc, multiReductionOp.getKind(), accRes); - auto insertOp = - b.create(loc, reductionOp->getResult(0), - newAccReadOp->getResults()[0], 0); - - // write vector back to tensor - vector::TransferWriteOp accWriteOp = nullptr; - for (auto [idx, x] : llvm::enumerate( - multiReductionOp->getResults()[0].getUsers())) { - if (idx == 0 && mlir::isa(x)) { - accWriteOp = mlir::dyn_cast(x); - break; - } - } - assert(accWriteOp && - " Not transfer_write operation. Current multireduction " - "operation may have wrong analysis IR."); - IRMapping accWriteindiceMap; - accWriteindiceMap.map(accWriteOp.getOperand(0), - insertOp->getResults()[0]); - auto writeResult = accWriteOp->getResults()[0]; - auto newAccWriteOp = makeNewTransferWriteOp( - writeResult, accWriteindiceMap, b, parallelAxis, inductionVars); - originalWriteResult = newAccWriteOp->getResult(0); - - maybeYieldValue(b, loc, newAccWriteOp->getResults()); - } else { - - auto &prevOps = getMultiRdCanonicalizers()[groupIdx].getPrevOps(); - auto &postOps = getMultiRdCanonicalizers()[groupIdx].getPostOps(); - auto &accRelatedOps = - getMultiRdCanonicalizers()[groupIdx].getAccRelatedOps(); - auto &sourceRelatedOps = - getMultiRdCanonicalizers()[groupIdx].getSourceRelatedOps(); - // prevOp + reduction op + postOp - // reduction op + postOp - getPrevOps(prevOps, opQueue, multiReductionOp); - getPostOps(postOps, opQueue, multiReductionOp); - bool hasPrevOps = hasOtherOperations(prevOps, multiReductionOp); - bool hasPostOps = hasOtherOperations(postOps, multiReductionOp); - - if (hasPostOps and !hasPrevOps) { - // multi_reduction + postOp - - } else { - // analysis acc related operation - classifySourceRelatedOps( - accRelatedOps, sourceRelatedOps, - multiReductionOp.getSource().getDefiningOp(), prevOps); - - rewriteOperationAsVectorize(b, groupIdx, accRelatedOps); - auto &grpArgs = getGroupOpIterArgs()[groupIdx]; - llvm::DenseMap operandIdxMap; - for (auto [idx, x] : llvm::enumerate(grpArgs)) { - operandIdxMap[x] = idx; - } - moveOperationsToCurrentForBody(groupIdx, b, inductionVars, - operandIdxMap, loopState, - accRelatedOps); - auto &grpResults = getGroupOpResults()[groupIdx]; - // next for loop - llvm::SmallVector iterArgsArray; - iterArgsArray.emplace_back(multiReductionAcc); - std::queue tmpSourceOps(sourceRelatedOps); - while (!tmpSourceOps.empty()) { - auto cur = tmpSourceOps.front(); - tmpSourceOps.pop(); - auto curResults = cur->getResults(); - for (auto x : curResults) { 
- if (grpResults.contains(x)) { - for (auto y : cur->getOperands()) { - if (grpArgs.contains(y)) { - iterArgsArray.emplace_back(y); - } - } - } + size_t opIndex = opIndexMap[multiReductionOp]; + SmallVector, 8> &opGroups = + fusionStrategy.getOpGroups(); + std::queue &opQueue = opGroups[opIndex]; + Value multiReductionAcc = multiReductionOp.getAcc(); + + // if (parallelIdx == parallelAxis.size() - 1) { + // four kinds of group operations + // If fused a operation, it means multirection must just + // constains last dim to do the reduction. + // 1. just multireduction + // two cases: + // 1. constaints last dims + // for ... parallel axis: + // transfer_read from accSource tensor + // arith.constant : vector<16xf32> + // for ... 16: + // for ... reduction axis: + // add + // scalar = reduction vector + // scalar insert into accVector + // transfer_write accVector into emtpy tensor + // 2. not last dims + // for ... generate axis: + // transfer_read from accSource tensor + // transfer_read from source tensor + // accVector = add + // transfer_write accVector into emtpy tensor + // 2. prev-op + multireduction + // In this case, there will be no tensor.empty + transfer_read + // operation, but the multireduction should write in an empty + // tensor + // for ... parallel axis: + // accVector and related accVector operation should be here + // extract from accVector scalar + // airth.constant : vector<16xf32> + // for ... reduction axis: + // prevop source op + // add + // scalar = reduction vector + // scalar insert into accVector + // transfer_write accVector into empty tensor + // + // 3. post-op + multireduction + // for ... parallel axis: + // transferread from accSource tensor + // arith.constant : vector<16xf32> + // for ... reduction axis: + // add + // postOp + // post Op transferWrite emtpy tensor + // scalar = reduction vector + // scalar insert into accVector + // transfer_write accVector into emtpy tensor + // 4. prev-op + multireduction + post-op + // for ... parallel axis: + // accVector operation + // extract from accVector a scalar + // arith.constant : vector<16xf32> + // for ... reduction axis: + // prev-op source op and related source operation + // add + // postOp + // post Op transferWrite emtpy tensor + // scalar = reduction vector + // scalar insert into accVector + // transfer_write accVector into emtpy tensor + + if (parallelIdx < parallelAxis.size()) { + // 1. 
move pre-Op to current body + DenseMap nextAnchorArgsIdxMap; + SmallVector nextAnchorArgs; + std::queue movedQueue; + movePreOpToCurrentAnchor( + parallelIdx, groupIdx, b, inductionVars, loopState, + currentLoopStateIdxMap, nextAnchorArgsIdxMap, nextAnchorArgs, + opQueue, movedQueue, originalOperandLoopArgsMap); + + if (parallelIdx == parallelAxis.size() - 1) { + std::queue checkAccQueue(movedQueue); + Value accInitVal; + while (!checkAccQueue.empty()) { + Operation *cur = checkAccQueue.front(); + checkAccQueue.pop(); + bool ok = false; + for (auto x : cur->getResults()) { + if (x == multiReductionAcc) { + accInitVal = x; + ok = true; + break; } } - ValueRange reductionAxisArgs(iterArgsArray); - auto nxtFor = parallelAxisGenerateForLoop( - b, groupIdx, parallelIdx + 1, reductionAxisArgs, - inductionVars, originalWriteResult); - if (hasPrevOps and !hasPostOps) { - // prevOp + reduction op - - } else { - - rewriteOperationAsVectorize(b, groupIdx, postOps); - moveOperationsToCurrentForBody(groupIdx, b, inductionVars, - operandIdxMap, loopState, - postOps); - - auto replaceIfFn = [&](OpOperand &use) { - return use.getOwner()->getBlock() == nxtFor->getBlock(); - }; - rewriterOfFunc.replaceOpUsesWithIf( - multiReductionOp, nxtFor->getResults()[0], replaceIfFn); - auto &originalResults = - getMultiRdCanonicalizers()[groupIdx].getOriginalOpResults(); - for (auto [idx, x] : llvm::enumerate(originalResults)) { - rewriterOfFunc.replaceOpUsesWithIf( - x.getDefiningOp(), nxtFor->getResults()[idx + 1], - replaceIfFn); - } - llvm::SmallVector resultsArray; - llvm::SmallDenseMap parallelIdxMap; - for (auto &x : grpResults) { - if (originalResults.contains(x)) { - auto &idxMap = rdCanonicalizer.getResultIdxMap(); - resultsArray.emplace_back(nxtFor->getResults()[idxMap[x]]); - } else { - resultsArray.emplace_back(x); - } - parallelIdxMap.insert({x, resultsArray.size() - 1}); - } - rdCanonicalizer.setResultIdxMap(parallelIdxMap); - maybeYieldValue(b, loc, resultsArray); - - // prepare iterArgs - } + if (ok) + break; } + assert(accInitVal && " Can't find accInit Value"); + nextAnchorArgs.emplace_back(accInitVal); + nextAnchorArgsIdxMap[accInitVal] = nextAnchorArgs.size() - 1; + originalOperandLoopArgsMap[multiReductionAcc] = accInitVal; } - } else { - if (parallelIdx == parallelAxis.size() && !isStandaloneOp) { - - Attribute initValueAttr; - getReductionInitAttr(multiReductionOp, initValueAttr); - - auto accVal = b.create( - loc, DenseElementsAttr::get(getVectorzedType(multiReductionOp), - {initValueAttr})); - llvm::SmallVector argsArray; - argsArray.emplace_back(accVal); - for (auto [idx, x] : llvm::enumerate(loopState)) { - if (idx == 0) - continue; - argsArray.emplace_back(x); + // 2. generate next for loop + scf::ForOp nxtFor = parallelAxisGenerateForLoop( + b, groupIdx, parallelIdx + 1, nextAnchorArgsIdxMap, + nextAnchorArgs, nextAnchorResults, nextAnchorResultsIdxMap, + inductionVars, originalOperandLoopArgsMap, + forResultOrignalResultMap); + + // 3. move postOp to current body + movePostOpToCurrentAnchor( + b, parallelIdx, groupIdx, nxtFor->getResults(), + nxtFor->getBlock(), opQueue, movedQueue, inductionVars, + currentLoopStateIdxMap, loopState, nextAnchorResults, + forResultOrignalResultMap); + + // 4. 
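// At the last parallel axis the code above scans the operations it has just
// hoisted for the one whose result is the multi-reduction accumulator, then
// threads that value in as an extra iter-arg for the inner loops. A simplified
// standalone sketch of that scan; the Op/Value stand-ins, findAccInit, and the
// example ops in main() are illustrative.
#include <iostream>
#include <optional>
#include <queue>
#include <string>
#include <vector>

using Value = std::string;
struct Op {
  std::string name;
  std::vector<Value> results;
};

// Return the hoisted result that feeds the reduction accumulator, if any.
std::optional<Value> findAccInit(std::queue<Op> moved, const Value &acc) {
  while (!moved.empty()) {
    for (const Value &r : moved.front().results)
      if (r == acc)
        return r;
    moved.pop();
  }
  return std::nullopt;
}

int main() {
  std::queue<Op> moved;
  moved.push({"transfer_read_src", {"%src_vec"}});
  moved.push({"transfer_read_acc", {"%acc_vec"}});
  auto acc = findAccInit(moved, "%acc_vec");
  std::cout << (acc ? *acc : "<no acc init found>") << "\n";
  return 0;
}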
generate loop results + generateLoopResults(b, loc, parallelIdx, groupIdx, nextAnchorResults, + nextAnchorResultsIdxMap, nxtFor->getResults(), + movedQueue, forResultOrignalResultMap); + + } else if (parallelIdx == parallelAxis.size()) { + + // get accumualte value + Attribute initValueAttr; + getReductionInitAttr(multiReductionOp, initValueAttr); + auto accVal = b.create( + loc, DenseElementsAttr::get(getVectorzedType(multiReductionOp), + {initValueAttr})); + + DenseMap localAnchorResultsIdxMap; + SmallVector argsArray; + argsArray.emplace_back(accVal); + localAnchorResultsIdxMap[accVal] = 0; + size_t accLoopStateIdx = currentLoopStateIdxMap + [originalOperandLoopArgsMap[multiReductionAcc]]; + + for (auto [idx, x] : llvm::enumerate(loopState)) { + if (idx == accLoopStateIdx) { + continue; } - ValueRange newIterArgs(argsArray); - auto nxtFor = reductionAxisGenerateForLoop( - b, groupIdx, 0, newIterArgs, inductionVars); - // insert accumulate value to original vector - auto accRes = nxtFor->getResults()[0]; - - Operation *reductionOp = b.create( - loc, multiReductionOp.getKind(), accRes); - auto insertOp = b.create( - loc, reductionOp->getResult(0), initArgs[0], iv); - auto insertResult = insertOp->getResults()[0]; - - // result - llvm::SmallVector retResults; - retResults.emplace_back(insertResult); - for (auto [idx, x] : llvm::enumerate(nxtFor->getResults())) { - if (idx == 0) { - continue; - } - retResults.emplace_back(x); + argsArray.emplace_back(x); + localAnchorResultsIdxMap[accVal] = argsArray.size() - 1; + } + ValueRange reductionInitArgs(argsArray); + auto nxtFor = reductionAxisGenerateForLoop( + b, groupIdx, 0, parallelIdx, localAnchorResultsIdxMap, + reductionInitArgs, nextAnchorResults, nextAnchorResultsIdxMap, + inductionVars, forResultOrignalResultMap); + + // insert accumulate value to original vector + auto accRes = nxtFor->getResults()[0]; + + Operation *reductionOp = b.create( + loc, multiReductionOp.getKind(), accRes); + auto insertOp = b.create( + loc, reductionOp->getResult(0), accVal, iv); + + // generate loop result + SmallVector currentAnchorResults; + DenseMap currentResultMap; + DenseMap currentResultIdxMap; + + currentAnchorResults.emplace_back(insertOp->getResults()[0]); + // reduce axis for loop first result we has already processed above + currentResultMap[insertOp->getResults()[0]] = + multiReductionOp->getResults()[0]; + currentResultIdxMap[insertOp->getResults()[0]] = 0; + + for (auto [idx, x] : llvm::enumerate(nextAnchorResults)) { + if (idx == 0) { + continue; } - ValueRange retResultsArray(retResults); - maybeYieldValue(b, loc, retResultsArray); - - } else { - auto nxtFor = parallelAxisGenerateForLoop( - b, groupIdx, parallelIdx + 1, loopState, inductionVars, - originalWriteResult); - maybeYieldValue(b, loc, nxtFor->getResults()); + Value originalResult = forResultOrignalResultMap[x]; + size_t forResultIdx = nextAnchorResultsIdxMap[x]; + currentAnchorResults.emplace_back( + nxtFor->getResults()[forResultIdx]); + currentResultIdxMap[nxtFor->getResults()[forResultIdx]] = idx; + currentResultMap[nxtFor->getResults()[forResultIdx]] = + originalResult; } + nextAnchorResults.clear(); + nextAnchorResults = std::move(currentAnchorResults); + forResultOrignalResultMap = std::move(currentResultMap); + nextAnchorResultsIdxMap = std::move(currentResultIdxMap); + maybeYieldValue(b, loc, nextAnchorResults); } }); - return forOp; } scf::ForOp ForLoopGenerator::generateMultiReductionForLoop(const size_t grpIdx) { - auto &grpArgs = getGroupOpIterArgs()[grpIdx]; - 
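// In the lane loop above each iteration seeds a fresh accumulator vector with
// the reduction's neutral value, runs the reduction-axis loops, folds that
// vector to a scalar, and inserts the scalar into lane iv of the carried
// accumulator. A scalar C++ emulation of that per-lane pattern; the fixed 16
// f32 lanes, the add reduction, and the synthetic input are assumptions for
// the sketch only.
#include <array>
#include <cstddef>
#include <iostream>
#include <numeric>
#include <vector>

constexpr std::size_t kLanes = 16;

int main() {
  // Synthetic source: one 64-element reduction row per output lane.
  std::vector<std::array<float, 64>> rows(kLanes);
  for (auto &row : rows)
    row.fill(1.0f);

  std::array<float, kLanes> accVector{}; // carried accumulator, one lane per iv
  for (std::size_t iv = 0; iv < kLanes; ++iv) {               // dim == lanes loop
    std::array<float, kLanes> partial{};                      // neutral-value constant
    for (std::size_t r = 0; r < rows[iv].size(); r += kLanes) // reduction axis
      for (std::size_t l = 0; l < kLanes; ++l)
        partial[l] += rows[iv][r + l];                        // vectorized add
    float scalar =
        std::accumulate(partial.begin(), partial.end(), 0.0f); // vector.reduction
    accVector[iv] = scalar;                                     // insert at lane iv
  }
  std::cout << accVector[0] << "\n"; // 64
  return 0;
}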
llvm::SmallVector forLoopArgs(grpArgs.begin(), grpArgs.end()); - llvm::SmallVector inductionVars; - ValueRange initArgs(forLoopArgs); - Value originalWriteResult; auto &rdCanonicalizer = getMultiRdCanonicalizers()[grpIdx]; - auto &rdResultMap = rdCanonicalizer.getResultIdxMap(); - IRRewriter rewriter(func); + // get current loop init args + SetVector &grpArgs = getGroupOpInitArgs()[grpIdx]; + SmallVector forLoopArgs(grpArgs.begin(), grpArgs.end()); + ValueRange initArgs(forLoopArgs); + DenseMap currentLoopStateIdxMap; + DenseMap nextAnchorResultsIdxMap; + // map original operation operand with loop args + DenseMap originalOperandLoopArgsMap, forResultOrignalResultMap; + for (auto [idx, val] : llvm::enumerate(initArgs)) { + currentLoopStateIdxMap[val] = idx; + originalOperandLoopArgsMap[val] = val; + } + SmallVector inductionVars; + IRRewriter rewriter(func); OpBuilder opBuilder(rdCanonicalizer.getCandidateOps()[0]); + SmallVector nextAnchorResults; scf::ForOp forOp = parallelAxisGenerateForLoop( - opBuilder, grpIdx, 0, initArgs, inductionVars, originalWriteResult); + opBuilder, grpIdx, 0, currentLoopStateIdxMap, initArgs, nextAnchorResults, + nextAnchorResultsIdxMap, inductionVars, originalOperandLoopArgsMap, + forResultOrignalResultMap); auto replaceIfFn = [&](OpOperand &use) { return use.getOwner()->getBlock() == forOp->getBlock(); }; - for (auto &grpResult : getGroupOpResults()[grpIdx]) { - rewriter.replaceOpUsesWithIf(grpResult.getDefiningOp(), - forOp->getResults()[rdResultMap[grpResult]], - replaceIfFn); + for (auto x : nextAnchorResults) { + auto originalResult = forResultOrignalResultMap[x]; + rewriter.replaceOpUsesWithIf( + originalResult.getDefiningOp(), + forOp->getResults()[nextAnchorResultsIdxMap[x]], replaceIfFn); } - rewriter.replaceOp(getMultiRdCanonicalizers()[grpIdx].getCandidateOps()[0], forOp); + return forOp; } template -llvm::SmallVector &SpecialOperationCanonicalizer::getCandidateOps() { +SmallVector &SpecialOperationCanonicalizer::getCandidateOps() { return candidateRdOps; }; void MultiReductionCanonicalizer::initReductionAxis() { auto reductionAxisRange = getCandidateOps()[0].getReductionDims().getAsValueRange(); - auto reductionRange = llvm::to_vector<4>(llvm::map_range( + auto reductionRange = llvm::to_vector<4>(map_range( reductionAxisRange, [](const APInt &a) { return a.getZExtValue(); })); reductionAxis.assign(reductionRange.begin(), reductionRange.end()); } @@ -1320,7 +1503,7 @@ void MultiReductionCanonicalizer::prepareSpecialOperationInfo() { hasLastDimReduction(); }; -template void addDummyInit(llvm::SmallVector &canonicalizer) { +template void addDummyInit(SmallVector &canonicalizer) { canonicalizer.emplace_back(T({})); }; @@ -1339,9 +1522,9 @@ void CanonicalizerVectorOperation::dummyInitSpecialOperation() { } void CanonicalizerVectorOperation::initSpeicalOperationCanonicalizers() { - clearSpecialOperationCanonicalizers(); - auto &opGroups = getFusionStrategy().getOpGroups(); + SmallVector, 8> &opGroups = + getFusionStrategy().getOpGroups(); for (auto &grp : opGroups) { dummyInitSpecialOperation(); if (grp.empty()) { @@ -1369,30 +1552,29 @@ void CanonicalizerVectorOperation::initSpeicalOperationCanonicalizers() { } } -LogicalResult CanonicalizerVectorOperation::canonicalizeReductionOperation() { +void CanonicalizerVectorOperation::canonicalizeSpecialOperation() { + // multireduction operation OpBuilder::InsertionGuard guard(rewriter); initSpeicalOperationCanonicalizers(); // traverse all groups - auto &multiRdCanonicalizers = getMultiRdCanonicalizers(); + 
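// generateMultiReductionForLoop above seeds the outer loop from the group's
// init args, records each arg's iter-arg position and the original operand it
// replaces, and afterwards redirects users of the original results to the
// matching scf.for results through the rebuilt index map. A small sketch of
// just that bookkeeping; the Value alias and the sample values are
// illustrative.
#include <cstddef>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

using Value = std::string;

int main() {
  // Group init args become the outermost iter_args, in order.
  std::vector<Value> initArgs{"%acc_tensor", "%out_tensor"};
  std::unordered_map<Value, std::size_t> loopStateIdx;
  std::unordered_map<Value, Value> operandToLoopArg;
  for (std::size_t i = 0; i < initArgs.size(); ++i) {
    loopStateIdx[initArgs[i]] = i;
    operandToLoopArg[initArgs[i]] = initArgs[i]; // identity before any rewriting
  }

  // After the nest is built: the final yield order and the original values it
  // stands for, used to rewire uses outside the loop.
  std::vector<Value> nextAnchorResults{"%for_res0"};
  std::unordered_map<Value, std::size_t> nextAnchorResultsIdx{{"%for_res0", 0}};
  std::unordered_map<Value, Value> forResultToOriginal{{"%for_res0", "%mulred_res"}};

  for (const Value &r : nextAnchorResults)
    std::cout << forResultToOriginal[r] << " -> scf.for result #"
              << nextAnchorResultsIdx[r] << "\n";
  return 0;
}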
llvm::SmallVector &multiRdCanonicalizers = + getMultiRdCanonicalizers(); + llvm::SmallVector &transposeCanonicalizers = + getTransposeCanonicalizers(); for (auto [groupId, rdCanonicalizer] : llvm::enumerate(multiRdCanonicalizers)) { - auto &candidateOps = rdCanonicalizer.getCandidateOps(); - if (candidateOps.empty()) { - continue; + SmallVector &rdOps = + rdCanonicalizer.getCandidateOps(); + if (!rdOps.empty()) { + // generate MultiReduction for loops + (void)generateMultiReductionForLoop(groupId); + } + SmallVector &transposeOps = + transposeCanonicalizers[groupId].getCandidateOps(); + if (!transposeOps.empty()) { + // (void) generateTransposeForLoop(groupId); } - // generate MultiReduction for loops - (void)generateMultiReductionForLoop(groupId); - } - return success(); -} - -void CanonicalizerVectorOperation::canonicalizeSpecialOperation() { - // multireduction operation - auto result = canonicalizeReductionOperation(); - if (failed(result)) { - LDBG("Failed to canonicalize reduction operation\n"); - assert(0 && "Failed to canonicalize reduction operation"); } } @@ -1402,17 +1584,17 @@ void CanonicalizerVectorOperation::run() { // 1. Analysis the operation's operands and results // We need to analyze which operation results are needed by other // operations, and we need to pass these results correctly. Mapping the - // operation result value to forloop yeild result value. We can replace the - // operation operand as: map(operand, forloop yield result) -> operand = - // loop yield result We put all the operation result into this map. + // operation result value to forloop yeild result value. We can replace + // the operation operand as: map(operand, forloop yield result) -> operand + // = loop yield result We put all the operation result into this map. // 1.a. Find results which should be generated by current group for // using as operands to other operations? - // Traverse all operations. If the operand of operations in other groups or - // outside the group is the result of the current group operation, then the - // current operation needs to generate a result. We use `setvector` to save - // the results that need to be generated by the current group. + // Traverse all operations. If the operand of operations in other groups + // or outside the group is the result of the current group operation, then + // the current operation needs to generate a result. We use `setvector` to + // save the results that need to be generated by the current group. // 1.b. What operands are needed to find in the current group, and where // can they be obtained ? @@ -1420,24 +1602,23 @@ void CanonicalizerVectorOperation::run() { // Thanks to 1.a, we get the result generated by the operations of // each group, and this result will use `for loop yield` to generate a // new result. Since the scope of the parent block of mlir is covered - // the current operation, the current operation does not need to pass these - // `for loop results` to the `iterArgs` of the required `for loop`. It - // only needs to replace the operand of the current operation with the + // the current operation, the current operation does not need to pass + // these `for loop results` to the `iterArgs` of the required `for loop`. + // It only needs to replace the operand of the current operation with the // corresponding `for loop yield result`. // However, for some operations that are not DPS, we need to canonicalize // them. 
Canonicalization means that the operand of this operation is a // vector but we can't get this vector due to it locates in another block // which has a different scope. Therefore, it is necessary to write the - // vector results into a temporary tensor to save it. Then the vector needs - // to be read from the tensor before the current operation operate on it. - // Therefore, `empty tensor`, `transfer_write` and `transfer_read` need to - // be inserted at target place. + // vector results into a temporary tensor to save it. Then the vector + // needs to be read from the tensor before the current operation operate + // on it. Therefore, `empty tensor`, `transfer_write` and `transfer_read` + // need to be inserted at target place. // Query groupResultYeildSet to map operaion result value to scf.yield // result value. analysisGroupOperaion(); - printGroupOps(getFusionStrategy().getOpGroups()); // Speical Operation Canonicalization canonicalizeSpecialOperation(); @@ -1483,17 +1664,17 @@ void CanonicalizerVectorOperation::run() { // void setOperationCorrectOperand( Operation *op, const ValueRange &iterArgs, - const llvm::DenseMap &operandIdxMap, - const llvm::SmallVector &inductionVars, - const llvm::DenseMap &opPermuationMap) { + const DenseMap &operandIdxMap, ArrayRef inductionVars, + const DenseMap &opPermuationMap) { for (auto [idx, opd] : llvm::enumerate(op->getOperands())) { if (operandIdxMap.contains(opd)) { op->setOperand(idx, iterArgs[operandIdxMap.at(opd)]); } } int offset = isa(op) ? 2 : 1; - if (llvm::dyn_cast(op) || - llvm::dyn_cast(op)) { + if (dyn_cast(op) || + dyn_cast(op)) { + assert(opPermuationMap.contains(op)); auto permutationMap = opPermuationMap.at(op); @@ -1509,10 +1690,9 @@ void setOperationCorrectOperand( scf::ForOp ForLoopGenerator::constructNestedForOp( const size_t forDimIdx, const size_t groupIdx, OpBuilder &b, - const Location &loc, const ValueRange &iterArgs, const VectorType &type, - const llvm::ArrayRef &dims, - llvm::SmallVector &inductionVars, - const llvm::DenseMap &operandIdxMap) { + const Location &loc, const ValueRange &iterArgs, VectorType type, + const ArrayRef &dims, SmallVector &inductionVars, + const DenseMap &operandIdxMap) { const int loop_step = getDataTypeValidSteps(type); // loop initialization variable auto zero = makeIndexArithConstantOp(b, loc, 0); @@ -1531,8 +1711,14 @@ scf::ForOp ForLoopGenerator::constructNestedForOp( moveOperationsToCurrentForBody( groupIdx, b, inductionVars, operandIdxMap, loopState, getFusionStrategy().getOpGroups()[groupIdx]); - auto &resultSet = getGroupOpResults()[groupIdx]; - maybeYieldValue(b, loc, resultSet.getArrayRef()); + llvm::MapVector> &resultSet = + getGroupOpResults()[groupIdx]; + SmallVector results(resultSet.size()); + size_t idx = 0; + for (auto itr = resultSet.begin(); itr != resultSet.end(); itr++) { + results[idx++] = itr->first; + } + maybeYieldValue(b, loc, results); } else { // outter loop auto nxtFor = @@ -1566,7 +1752,8 @@ bool isSameVectorType(Operation *op1, Operation *op2) { return isSame; } -bool isCompatibleVectorType(Operation *op1, Operation *op2) { +bool VectorFusionStrategy::isCompatibleVectorType(Operation *op1, + Operation *op2) { auto type1 = getOperationVectorType(op1); auto type2 = getOperationVectorType(op2); if (failed(type1) || failed(type2)) { @@ -1597,34 +1784,9 @@ bool isCompatibleVectorType(Operation *op1, Operation *op2) { return isCompatible; } -bool isPartialCompatible(Operation *op1, Operation *op2) { - auto type1 = getOperationVectorType(op1); - auto type2 = 
getOperationVectorType(op2); - if (failed(type1) || failed(type2)) { - return false; - } - auto sp1 = type1.value(); - auto sp2 = type2.value(); - // must be total same - if (isReadOrWriteOperation(op1) or isReadOrWriteOperation(op2)) { - return false; - } - bool isCompatible = true; - auto min_rank = std::min(sp1.getRank(), sp2.getRank()); - // from front to back - for (long i = 0; i < min_rank; i++) { - if (sp1.getDimSize(i) != sp2.getDimSize(i)) { - isCompatible = false; - break; - } - } - - return isCompatible; -} - /// which axis do the shape cast in source shape a void shapeCastSourceAxis(const ArrayRef &a, const ArrayRef &b, - llvm::SmallVector &res) { + SmallVector &res) { unsigned rankA = a.size(); unsigned rankB = b.size(); assert(rankA < rankB && "May be invalid shape cast operation."); @@ -1633,7 +1795,7 @@ void shapeCastSourceAxis(const ArrayRef &a, const ArrayRef &b, // Special-case for n-D to 0-d shape cast. 'b' must be all ones to be shape // casted to a 0-d vector. - if (rankA == 0 && llvm::all_of(b, isOne)) { + if (rankA == 0 && all_of(b, isOne)) { for (size_t i = 0; i < a.size(); i++) { res.emplace_back(i); } @@ -1659,9 +1821,9 @@ void shapeCastSourceAxis(const ArrayRef &a, const ArrayRef &b, // Handle the case when trailing dimensions are of size 1. // Include them into the contiguous sequence. - if (i < rankA && llvm::all_of(a.slice(i), isOne)) + if (i < rankA && all_of(a.slice(i), isOne)) i = rankA; - if (j < rankB && llvm::all_of(b.slice(j), isOne)) + if (j < rankB && all_of(b.slice(j), isOne)) j = rankB; } @@ -1678,7 +1840,7 @@ bool isScalar(Type type) { } void getSrcBroadcastDim(const ShapedType &input, const ShapedType &output, - llvm::SmallVector &bcAxis) { + SmallVector &bcAxis) { auto inputShape = input.getShape(); auto outputShape = output.getShape(); // following auto_broadcast semantics @@ -1699,13 +1861,13 @@ void getSrcBroadcastDim(const ShapedType &input, const ShapedType &output, } } -void getOperationDataAxis(Operation *op, llvm::SmallVector &dataAxis) { +void getOperationDataAxis(Operation *op, SmallVector &dataAxis) { return TypeSwitch(op) .Case( [&](vector::MultiDimReductionOp multiReductionOp) { auto rdDimsRange = multiReductionOp.getReductionDims() .getAsValueRange(); - auto reductionDims = llvm::to_vector<4>(llvm::map_range( + auto reductionDims = llvm::to_vector<4>(map_range( rdDimsRange, [](const APInt &a) { return a.getZExtValue(); })); dataAxis.assign(reductionDims.begin(), reductionDims.end()); }) @@ -1762,9 +1924,9 @@ bool hasDataDependency(Operation *op1, Operation *op2) { if (disableSpecialOp) { return true; } - auto hasSameAxis = [](const llvm::SmallVector &dims1, - const llvm::SmallVector &dims2) { - llvm::DenseSet checkSet(dims2.begin(), dims2.end()); + auto hasSameAxis = [](const SmallVector &dims1, + const SmallVector &dims2) { + DenseSet checkSet(dims2.begin(), dims2.end()); for (auto x : dims1) { if (checkSet.contains(x)) { return true; @@ -1775,7 +1937,7 @@ bool hasDataDependency(Operation *op1, Operation *op2) { auto res = TypeSwitch(op1) .Case([&](vector::ShapeCastOp shapeCastOp) { - llvm::SmallVector dims1, dims2; + SmallVector dims1, dims2; getOperationDataAxis(op1, dims1); getOperationDataAxis(op2, dims2); return hasSameAxis(dims1, dims2); @@ -1784,10 +1946,10 @@ bool hasDataDependency(Operation *op1, Operation *op2) { [&](vector::MultiDimReductionOp multiReductionOp) { // has two cases: op1 is special operation, op2 is normal // operation op1 and op2 is both speicial operation - llvm::SmallVector dims2, reductionDims, 
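// shapeCastSourceAxis above maps every axis of the smaller-rank shape onto the
// contiguous run of destination axes whose sizes multiply up to it, folding
// trailing unit dimensions into the last run. The grouping below is a
// simplified reconstruction of that idea on plain shape vectors, not the
// patch's exact routine; groupCastAxes and the shapes in main() are
// illustrative.
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<std::vector<std::size_t>>
groupCastAxes(const std::vector<int64_t> &a, const std::vector<int64_t> &b) {
  assert(a.size() < b.size() && "expected a rank-expanding shape cast");
  std::vector<std::vector<std::size_t>> groups(a.size());
  std::size_t j = 0;
  for (std::size_t i = 0; i < a.size(); ++i) {
    int64_t dimB = 1;
    while (dimB < a[i] && j < b.size()) {
      dimB *= b[j];
      groups[i].push_back(j++);
    }
    assert(dimB == a[i] && "invalid shape cast");
  }
  while (j < b.size() && b[j] == 1) // absorb trailing unit dims
    groups.back().push_back(j++);
  return groups;
}

int main() {
  auto groups = groupCastAxes({2, 12}, {2, 3, 4, 1});
  for (std::size_t i = 0; i < groups.size(); ++i) {
    std::cout << "a[" << i << "] <-";
    for (std::size_t axis : groups[i])
      std::cout << " b[" << axis << "]";
    std::cout << "\n";
  }
  return 0;
}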
parallelDims; + SmallVector dims2, reductionDims, parallelDims; getOperationDataAxis(op1, reductionDims); getOperationDataAxis(op2, dims2); - llvm::DenseSet checkSet(dims2.begin(), dims2.end()); + DenseSet checkSet(dims2.begin(), dims2.end()); auto op2VectorType = getOperationVectorType(op2); if (!isSpecialOp(op2)) { if (isSameVectorType(op1, op2)) { @@ -1825,20 +1987,11 @@ bool hasDataDependency(Operation *op1, Operation *op2) { return false; } - // else { - // // TODO: reduce operation fused with other special - // operation if (mlir::isa(op2)) { - // return true; - // } else if (mlir::isa(op2)) { - // return true; - // } - // //... - // } return true; }) .Case([&](vector::BroadcastOp broadcastOp) { - llvm::SmallVector dims1, dims2; + SmallVector dims1, dims2; getOperationDataAxis(op1, dims1); getOperationDataAxis(op2, dims2); return true; @@ -1849,7 +2002,7 @@ bool hasDataDependency(Operation *op1, Operation *op2) { return true; }) .Case([&](vector::TransposeOp transposeOp) { - llvm::SmallVector dims1, dims2; + SmallVector dims1, dims2; getOperationDataAxis(op1, dims1); getOperationDataAxis(op2, dims2); return true; @@ -1885,13 +2038,33 @@ bool VectorFusionStrategy::isNeedNewGroup(Operation *op) { return false; } +void VectorFusionStrategy::updateGroupBitgestVectorType(VectorType vectorType) { + int64_t rank = vectorType.getRank(); + llvm::SmallDenseMap &groupVectorType = + getGroupBiggestRankVectorType(); + + if (groupVectorType.contains(opGroups.size() - 1)) { + VectorType bigestType = groupVectorType[opGroups.size() - 1]; + if (bigestType.getRank() < rank) { + groupVectorType[opGroups.size() - 1] = vectorType; + } + return; + } + + groupVectorType[opGroups.size() - 1] = vectorType; +} + void VectorFusionStrategy::addOperationToGroup(Operation *op) { assert(op); + VectorType vectorType = getOperationVectorType(op).value(); if (isNeedNewGroup(op)) { opGroups.emplace_back(std::queue()); + } else { + updateGroupBitgestVectorType(vectorType); } opGroups.back().push(op); opGroupIndexMap[op] = opGroups.size() - 1; + opAnchorPos[op] = getOperationVectorType(op)->getRank() - 1; } // We classify the operations we are interested in after filtering. Operations @@ -1939,11 +2112,10 @@ Value setOutGroupOperationOperandResult(Operation *op, } } else if (isa(resultElementType)) { initValueAttr = FloatAttr::get( - resultElementType, - llvm::cast(value).getValueAsDouble()); + resultElementType, cast(value).getValueAsDouble()); } else { initValueAttr = IntegerAttr::get( - resultElementType, llvm::cast(value).getInt()); + resultElementType, cast(value).getInt()); } auto cntOp = rewriter.create( @@ -1955,9 +2127,8 @@ Value setOutGroupOperationOperandResult(Operation *op, return ret; } -void setOperationOperandResult( - Operation *op, const VectorType &newOperandType, - const llvm::DenseMap &opMap) { +void setOperationOperandResult(Operation *op, const VectorType &newOperandType, + const DenseMap &opMap) { for (auto [idx, x] : llvm::enumerate(op->getOperands())) { if (mlir::dyn_cast(x.getType())) { if (!opMap.contains(x.getDefiningOp())) { @@ -2016,13 +2187,13 @@ void ForLoopGenerator::createNewConstantOp( /// Rewrite the operations in the group to vectorized form. void ForLoopGenerator::rewriteOperationAsVectorize( - OpBuilder &rewriter, size_t groupId, const std::queue &queue) { + OpBuilder &rewriter, size_t groupId, const std::queue *queue) { const std::queue groupOps = - queue.empty() ? getFusionStrategy().getOpGroups()[groupId] : queue; - const llvm::DenseMap &opMap = + !queue ? 
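// addOperationToGroup above records each op's anchor as rank-1 of its vector
// type and keeps, per group, the vector type with the largest rank so the loop
// nest is later sized by the widest member. A plain-C++ sketch of that
// per-group tracking; shapes stand in for VectorType and the sample group in
// main() is illustrative.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <unordered_map>
#include <vector>

using Shape = std::vector<int64_t>; // stand-in for VectorType

int main() {
  std::vector<Shape> groupOps{{16, 32}, {16, 32, 64}, {64}};
  std::unordered_map<std::size_t, std::size_t> opAnchorPos; // op index -> anchor
  Shape groupBiggest;

  for (std::size_t opIdx = 0; opIdx < groupOps.size(); ++opIdx) {
    const Shape &type = groupOps[opIdx];
    opAnchorPos[opIdx] = type.size() - 1;  // anchor position = rank - 1
    if (type.size() > groupBiggest.size()) // remember the widest-rank type
      groupBiggest = type;
  }

  std::cout << "loops needed for this group: " << groupBiggest.size() << "\n"; // 3
  return 0;
}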
getFusionStrategy().getOpGroups()[groupId] : *queue; + + const DenseMap &opMap = getFusionStrategy().getOpGroupIndexMap(); - llvm::DenseMap &opPermuationMap = - getOpPermuationMap(); + DenseMap &opPermuationMap = getOpPermuationMap(); std::queue transformQueue(groupOps); size_t groupSteps = getFusionStrategy().getGroupMaxSteps()[groupId]; @@ -2069,6 +2240,7 @@ void ForLoopGenerator::rewriteOperationAsVectorize( }) .Case( [&](vector::MultiDimReductionOp multiReductionOp) { + multiReductionOp.dump(); llvm::llvm_unreachable_internal( "It should not appear this operation."); return failure(); @@ -2094,7 +2266,7 @@ void ForLoopGenerator::rewriteOperationAsVectorize( } mlir::FailureOr getOperationOperateTensor(Operation *op) { - return llvm::TypeSwitch>(op) + return TypeSwitch>(op) .Case( [&](vector::TransferWriteOp transferWriteOp) { LDBG(" DPS operation : " << *op << "\n"); @@ -2110,11 +2282,9 @@ mlir::FailureOr getOperationOperateTensor(Operation *op) { }); } -void updateOpOperandResultInGroups( - llvm::SmallVector, 8> &opGroups, - llvm::DenseMap &opGroupIndexMap, size_t opGid, - Operation *op, Value &init, const Value &result = Value()) { - std::queue tmpOpQueue(opGroups[opGid]); +void VectorOperationAnalysizer::updateOpOperandResultInGroups( + size_t opGid, Operation *op, Value &init, const Value &result) { + std::queue tmpOpQueue(getFusionStrategy().getOpGroups()[opGid]); std::queue newOpQueue; while (!tmpOpQueue.empty()) { auto curOp = tmpOpQueue.front(); @@ -2122,39 +2292,50 @@ void updateOpOperandResultInGroups( if (curOp == op) { if (!failed(getOperationVectorType(init.getDefiningOp()))) { newOpQueue.push(init.getDefiningOp()); - opGroupIndexMap[init.getDefiningOp()] = opGid; + getFusionStrategy().getOpGroupIndexMap()[init.getDefiningOp()] = opGid; + getFusionStrategy().getOpAnchorPos()[init.getDefiningOp()] = + getFusionStrategy().getOpAnchorPos()[op]; } newOpQueue.push(op); if (result && !failed(getOperationVectorType(result.getDefiningOp()))) { newOpQueue.push(result.getDefiningOp()); - opGroupIndexMap[result.getDefiningOp()] = opGid; + getFusionStrategy().getOpGroupIndexMap()[result.getDefiningOp()] = + opGid; + getFusionStrategy().getOpAnchorPos()[result.getDefiningOp()] = + getFusionStrategy().getOpGroupIndexMap()[op]; } } else { newOpQueue.push(curOp); } } - opGroups[opGid] = newOpQueue; + getFusionStrategy().getOpGroups()[opGid] = newOpQueue; } void VectorFusionStrategy::run() { classifyOperations(); } void VectorOperationAnalysizer::generateEmptyTensorAndWrite( - Operation *sourceOp, llvm::DenseMap> - &srcOpCanoniclizedMap) { - auto &opGroups = getFusionStrategy().getOpGroups(); - auto &opGroupIndexMap = getFusionStrategy().getOpGroupIndexMap(); - auto &groupOpIterArgs = getGroupOpIterArgs(); - auto &groupOpResults = getGroupOpResults(); - auto sourceOpGid = opGroupIndexMap[sourceOp]; + Operation *sourceOp, + DenseMap> &srcOpCanoniclizedMap, + size_t anchorPos, ReturnTypeKind retKind) { + DenseMap &opGroupIndexMap = + getFusionStrategy().getOpGroupIndexMap(); + SmallVector, 8> &groupOpInitArgs = getGroupOpInitArgs(); + SmallVector>, 8> + &groupOpResults = getGroupOpResults(); + size_t sourceOpGid = opGroupIndexMap[sourceOp]; auto [resultTensor, result] = canonicalizeSourceOperation(sourceOp); + auto writeOp = result.getDefiningOp(); srcOpCanoniclizedMap.insert({sourceOp, {resultTensor, result}}); - updateOpOperandResultInGroups(opGroups, opGroupIndexMap, sourceOpGid, - sourceOp, resultTensor, result); - groupOpIterArgs[sourceOpGid].insert(resultTensor); - 
groupOpResults[sourceOpGid].insert(result); + updateOpOperandResultInGroups(sourceOpGid, sourceOp, resultTensor, result); + groupOpInitArgs[sourceOpGid].insert(resultTensor); + groupOpResults[sourceOpGid].insert({result, {retKind, anchorPos}}); + groupOpResults[sourceOpGid].back().first.dump(); + getFusionStrategy().getOpAnchorPos()[writeOp] = + cast(writeOp).getVectorType().getRank() - 1; + getOpPermuationMap()[writeOp] = writeOp.getPermutationMap(); } void VectorOperationAnalysizer::analysisEmptyGroupAndMaxSteps() { @@ -2163,7 +2344,7 @@ void VectorOperationAnalysizer::analysisEmptyGroupAndMaxSteps() { // If the group operations do not have result need to be returned, these are // useless code. - for (auto [idx, grp] : enumerate(opGroups)) { + for (auto [idx, grp] : llvm::enumerate(opGroups)) { if (groupOpResults[idx].empty()) { std::queue().swap(grp); } @@ -2194,53 +2375,91 @@ void VectorOperationAnalysizer::analysisEmptyGroupAndMaxSteps() { // analysis operation result of current group whether needed by other // operation which out of current group void VectorOperationAnalysizer::analysisGroupOperationResults() { - llvm::DenseMap> srcOpCanoniclizedMap; - llvm::DenseSet movedOperationSet; - auto &opGroups = getFusionStrategy().getOpGroups(); - auto &opGroupIndexMap = getFusionStrategy().getOpGroupIndexMap(); - auto &groupOpIterArgs = getGroupOpIterArgs(); - auto &groupOpResults = getGroupOpResults(); + DenseMap> srcOpCanoniclizedMap; + DenseSet movedOperationSet; + DenseMap &opGroupIndexMap = + getFusionStrategy().getOpGroupIndexMap(); + SmallVector, 8> &groupOpInitArgs = getGroupOpInitArgs(); + SmallVector>, 8> + &groupOpResults = getGroupOpResults(); + DenseMap &OpAnchorPos = + getFusionStrategy().getOpAnchorPos(); + + auto updateReturnResultKind = [&](Operation *sourceOp, size_t sourceOpGid, + ReturnTypeKind rtKind) { + Value sourceResult; + if (srcOpCanoniclizedMap.contains(sourceOp)) { + sourceResult = srcOpCanoniclizedMap[sourceOp].second; + } else { + sourceResult = sourceOp->getResults()[0]; + } + size_t srcOpAnchor = groupOpResults[sourceOpGid][sourceResult].second; + ReturnTypeKind prevRtKind = groupOpResults[sourceOpGid][sourceResult].first; + srcOpAnchor = std::min(srcOpAnchor, OpAnchorPos[sourceOp]); + if (prevRtKind != rtKind) { + groupOpResults[sourceOpGid][sourceResult] = + std::make_pair(ReturnTypeKind::RT_Both, srcOpAnchor); + } else { + groupOpResults[sourceOpGid][sourceResult] = + std::make_pair(rtKind, srcOpAnchor); + } + groupOpResults[sourceOpGid].back().first.dump(); + }; + func.walk([&](Operation *op) { for (auto [idx, opd] : llvm::enumerate(op->getOperands())) { - auto sourceOp = opd.getDefiningOp(); + Operation *sourceOp = opd.getDefiningOp(); if (opGroupIndexMap.contains(sourceOp)) { auto sourceOpGid = opGroupIndexMap[sourceOp]; bool notInSameGroup = opGroupIndexMap.contains(op) && sourceOpGid != opGroupIndexMap[op]; - bool outOfGroup = !opGroupIndexMap.contains(op); - if (notInSameGroup or outOfGroup) { + bool outOfGroup = !opGroupIndexMap.contains(op); + // Different anchor in same group and source operation is in inner loop, + // we need to get source operation's result + bool inSameGroupNeedReturn = + !notInSameGroup and OpAnchorPos[sourceOp] > OpAnchorPos[op]; + ReturnTypeKind rtKind = inSameGroupNeedReturn + ? 
ReturnTypeKind::RT_InGroup + : ReturnTypeKind::RT_OutGroup; + + if (notInSameGroup or outOfGroup or inSameGroupNeedReturn) { // update init iterargs auto dstRet = getOperationOperateTensor(sourceOp); // need to generate tensor.emtpy and vector.transfer_write, write // operand to tensor and read operand from the tensor, generate // vector.transfer_read if (failed(dstRet)) { - // already generate result tensor + // already generate result tensor, special operation do the + // transformation by itself + if (isSpecialOp(sourceOp) and inSameGroupNeedReturn) { + continue; + } if (!srcOpCanoniclizedMap.contains(sourceOp)) { - generateEmptyTensorAndWrite(sourceOp, srcOpCanoniclizedMap); + generateEmptyTensorAndWrite(sourceOp, srcOpCanoniclizedMap, + OpAnchorPos[sourceOp], rtKind); + } else { + // udpate result return type + updateReturnResultKind(sourceOp, sourceOpGid, rtKind); } auto opInit = canonicalizeCurrentOperation( op, srcOpCanoniclizedMap[sourceOp].second, idx); - updateOpOperandResultInGroups(opGroups, opGroupIndexMap, - opGroupIndexMap[op], op, opInit); + updateOpOperandResultInGroups(opGroupIndexMap[op], op, opInit); } else { // if source operation is transfer_read, we need to generate a // same transfer_read operation like source operation. if (mlir::isa(sourceOp)) { - auto transferReadOp = - mlir::dyn_cast(sourceOp); + auto transferReadOp = cast(sourceOp); auto opInit = canonicalizeCurrentOperation(op, dstRet.value(), idx, &transferReadOp); - updateOpOperandResultInGroups(opGroups, opGroupIndexMap, - opGroupIndexMap[op], op, opInit); + updateOpOperandResultInGroups(opGroupIndexMap[op], op, opInit); } else { - groupOpIterArgs[sourceOpGid].insert(dstRet.value()); - groupOpResults[sourceOpGid].insert(opd); + groupOpInitArgs[sourceOpGid].insert(dstRet.value()); + updateReturnResultKind(sourceOp, sourceOpGid, rtKind); } } } @@ -2253,6 +2472,7 @@ void VectorOperationAnalysizer::analysisGroupOperationResults() { } }); analysisEmptyGroupAndMaxSteps(); +#undef RESULT_RETURN_TYPE LDBG("Complete analysis group operation results\n"); } @@ -2262,29 +2482,30 @@ void VectorOperationAnalysizer::analysisGroupOperaion() { } mlir::FailureOr ForLoopGenerator::generateVectorizedForLoop( - const size_t groupId, IRRewriter &rewriter, const VectorType &vectorType) { + const size_t groupId, IRRewriter &rewriter, VectorType vectorType) { auto &resultSet = getGroupOpResults(); - auto &dstOperandSet = getGroupOpIterArgs()[groupId]; + auto &initArgs = getGroupOpInitArgs()[groupId]; assert(!resultSet.empty() && "Expected non-empty value"); // prepare for loop iterargs - llvm::SmallVector operands; - llvm::DenseMap operandIdxMap; - for (auto [idx, x] : llvm::enumerate(dstOperandSet)) { + SmallVector operands; + DenseMap operandIdxMap; + for (auto [idx, x] : llvm::enumerate(initArgs)) { operands.emplace_back(x); operandIdxMap[x] = operands.size() - 1; } - ValueRange iterArgs(operands); + ValueRange forIterArgs(operands); auto shapes = vectorType.getShape(); - llvm::SmallVector inductionVars; + SmallVector inductionVars; // generate for loop auto forOp = constructNestedForOp( - 0, groupId, rewriter, rewriter.getUnknownLoc(), iterArgs, vectorType, + 0, groupId, rewriter, rewriter.getUnknownLoc(), forIterArgs, vectorType, shapes, inductionVars, operandIdxMap); return forOp; } -void updateLoopResultUses(llvm::SetVector &opResults, - scf::ForOp *forOp) { +void updateLoopResultUses( + llvm::MapVector> &opResults, + scf::ForOp *forOp) { if (opResults.empty()) { return; } @@ -2292,14 +2513,15 @@ void 
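// The analysis above tags every escaping result as in-group (needed at a
// shallower anchor of the same group), out-of-group, or both, and keeps the
// smallest anchor at which it must surface; a second sighting with a different
// kind widens it to RT_Both. A minimal sketch of that merge rule; the enum
// values follow the patch, while updateReturnKind and main() are illustrative.
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <string>
#include <unordered_map>
#include <utility>

enum class ReturnTypeKind { RT_InGroup, RT_OutGroup, RT_Both };
using Value = std::string;
using ResultInfo = std::pair<ReturnTypeKind, std::size_t>; // kind, anchor

void updateReturnKind(std::unordered_map<Value, ResultInfo> &groupResults,
                      const Value &result, ReturnTypeKind kind,
                      std::size_t anchor) {
  auto it = groupResults.find(result);
  if (it == groupResults.end()) {
    groupResults[result] = {kind, anchor};
    return;
  }
  std::size_t mergedAnchor = std::min(it->second.second, anchor);
  ReturnTypeKind mergedKind =
      it->second.first == kind ? kind : ReturnTypeKind::RT_Both;
  it->second = {mergedKind, mergedAnchor};
}

int main() {
  std::unordered_map<Value, ResultInfo> groupResults;
  updateReturnKind(groupResults, "%v", ReturnTypeKind::RT_InGroup, /*anchor=*/2);
  updateReturnKind(groupResults, "%v", ReturnTypeKind::RT_OutGroup, /*anchor=*/3);
  std::cout << "widened to RT_Both: "
            << (groupResults["%v"].first == ReturnTypeKind::RT_Both) << "\n"; // 1
  std::cout << "surfaces at anchor " << groupResults["%v"].second << "\n";    // 2
  return 0;
}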
updateLoopResultUses(llvm::SetVector &opResults, OpBuilder::InsertionGuard g(rewriter); // Only different group operation operand need to be replaced due to same // group operation should directly use original operand. - auto producerOp = opResults.front().getDefiningOp(); + + Operation *producerOp = opResults.begin()->first.getDefiningOp(); auto needToReplaced = [&](OpOperand &operand) { return producerOp->getBlock() != operand.getOwner()->getBlock(); }; // update loop result uses for (auto [retIdx, rt] : llvm::enumerate(opResults)) { - producerOp = rt.getDefiningOp(); - rewriter.replaceUsesWithIf(rt, forOp->getResult(retIdx), needToReplaced); + rewriter.replaceUsesWithIf(rt.first, forOp->getResult(retIdx), + needToReplaced); } } @@ -2326,18 +2548,14 @@ void ForLoopGenerator::generateGroupOpVectorizedIR(const int idx) { return; } auto &groupOpResults = getGroupOpResults(); - auto getType = getOperationVectorType(grp.front()); - if (failed(getType)) { - LDBG("Failed to get vector type for operation: " << *grp.front() << "\n"); - return; - } - auto opShapes = getType.value(); + VectorType groupType = + getFusionStrategy().getGroupBiggestRankVectorType()[idx]; IRRewriter rewriter(grp.back()); rewriter.setInsertionPointAfter(grp.back()); // 1. Rewrite operation as vectorized form // 2. Generate loop rewriteOperationAsVectorize(rewriter, idx); - auto forOp = generateVectorizedForLoop(idx, rewriter, opShapes); + auto forOp = generateVectorizedForLoop(idx, rewriter, groupType); // special operation do not need to change anything if (failed(forOp)) { return; diff --git a/test/gc/transforms/cpu-vetor-distribution.mlir b/test/gc/transforms/cpu-vetor-distribution.mlir index 60d572d2d..6e0561ac3 100644 --- a/test/gc/transforms/cpu-vetor-distribution.mlir +++ b/test/gc/transforms/cpu-vetor-distribution.mlir @@ -141,77 +141,77 @@ func.func @matmul_add(%arg0: tensor<8192x12288xf16>, %arg1: tensor<12288x16384xf return %3 : tensor<8192x16384xf32> } -#map = affine_map<(d0) -> (d0 * 64)> -#map1 = affine_map<(d0) -> (d0 * 128)> -#map2 = affine_map<(d0) -> (d0 floordiv 16)> -#map3 = affine_map<(d0) -> (d0 floordiv 32)> -#map4 = affine_map<(d0, d1, d2) -> (d0 + d1 + d2 * 128)> -#map5 = affine_map<(d0, d1, d2) -> (d0 + d1 + d2 * 64)> - func.func @mlp(%arg0: tensor<128x512xbf16>, %arg1: tensor<32x8x16x32xbf16>, %arg2: tensor<256xbf16>) -> tensor<128x256xbf16> { - %c32 = arith.constant 32 : index - %c512 = arith.constant 512 : index - %c128 = arith.constant 128 : index - %c64 = arith.constant 64 : index - %c0 = arith.constant 0 : index - %cst = arith.constant 0.000000e+00 : bf16 - %0 = tensor.empty() : tensor<128x256xbf16> - %1 = tensor.empty() : tensor<512x256xbf16> - %2:3 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %0, %arg6 = %0, %arg7 = %0) -> (tensor<128x256xbf16>, tensor<128x256xbf16>, tensor<128x256xbf16>) { - %3 = affine.apply #map(%arg3) - %4 = affine.apply #map1(%arg4) - %extracted_slice = tensor.extract_slice %arg0[%3, 0] [64, 512] [1, 1] : tensor<128x512xbf16> to tensor<64x512xbf16> - %extracted_slice_0 = tensor.extract_slice %arg5[%3, %4] [64, 128] [1, 1] : tensor<128x256xbf16> to tensor<64x128xbf16> - %5:3 = scf.for %arg8 = %c0 to %c64 step %c64 iter_args(%arg9 = %extracted_slice_0, %arg10 = %extracted_slice_0, %arg11 = %extracted_slice_0) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { - %6:3 = scf.for %arg12 = %c0 to %c128 step %c128 iter_args(%arg13 = %arg9, %arg14 = %arg10, %arg15 = %arg11) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { - 
%7:3 = scf.for %arg16 = %c0 to %c512 step %c512 iter_args(%arg17 = %arg13, %arg18 = %arg14, %arg19 = %arg15) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { - %extracted_slice_1 = tensor.extract_slice %extracted_slice[%arg8, %arg16] [64, 512] [1, 1] : tensor<64x512xbf16> to tensor<64x512xbf16> - %extracted_slice_2 = tensor.extract_slice %arg17[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> to tensor<64x128xbf16> - %8:3 = scf.for %arg20 = %c0 to %c64 step %c32 iter_args(%arg21 = %extracted_slice_2, %arg22 = %arg18, %arg23 = %arg19) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { - %9:3 = scf.for %arg24 = %c0 to %c128 step %c32 iter_args(%arg25 = %arg21, %arg26 = %arg22, %arg27 = %arg23) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { - %10:3 = scf.for %arg28 = %c0 to %c512 step %c512 iter_args(%arg29 = %arg25, %arg30 = %arg26, %arg31 = %arg27) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { - %extracted_slice_5 = tensor.extract_slice %extracted_slice_1[%arg20, %arg28] [32, 512] [1, 1] : tensor<64x512xbf16> to tensor<32x512xbf16> - %11 = affine.apply #map2(%arg28) - %12 = affine.apply #map3(%arg24) - %extracted_slice_6 = tensor.extract_slice %arg1[%11, %12, 0, 0] [32, 1, 16, 32] [1, 1, 1, 1] : tensor<32x8x16x32xbf16> to tensor<32x1x16x32xbf16> - %extracted_slice_7 = tensor.extract_slice %1[%arg28, %arg24] [512, 32] [1, 1] : tensor<512x256xbf16> to tensor<512x32xbf16> - %unpack = tensor.unpack %extracted_slice_6 inner_dims_pos = [0, 1] inner_tiles = [16, 32] into %extracted_slice_7 : tensor<32x1x16x32xbf16> -> tensor<512x32xbf16> - %extracted_slice_8 = tensor.extract_slice %arg29[%arg20, %arg24] [32, 32] [1, 1] : tensor<64x128xbf16> to tensor<32x32xbf16> - %13 = linalg.fill ins(%cst : bf16) outs(%extracted_slice_8 : tensor<32x32xbf16>) -> tensor<32x32xbf16> - %expanded = tensor.expand_shape %extracted_slice_5 [[0, 1], [2]] output_shape [1, 32, 512] : tensor<32x512xbf16> into tensor<1x32x512xbf16> - %expanded_9 = tensor.expand_shape %unpack [[0, 1], [2]] output_shape [1, 32, 512] : tensor<512x32xbf16> into tensor<1x512x32xbf16> - %14 = linalg.batch_reduce_matmul ins(%expanded, %expanded_9 : tensor<1x32x512xbf16>, tensor<1x512x32xbf16>) outs(%13 : tensor<32x32xbf16>) -> tensor<32x32xbf16> - %15 = affine.apply #map4(%arg12, %arg24, %arg4) - %16 = affine.apply #map5(%arg8, %arg20, %arg3) - %extracted_slice_10 = tensor.extract_slice %arg2[%15] [32] [1] : tensor<256xbf16> to tensor<32xbf16> - %extracted_slice_11 = tensor.extract_slice %0[%16, %15] [32, 32] [1, 1] : tensor<128x256xbf16> to tensor<32x32xbf16> - %broadcasted = linalg.broadcast ins(%extracted_slice_10 : tensor<32xbf16>) outs(%extracted_slice_11 : tensor<32x32xbf16>) dimensions = [0] - %extracted_slice_12 = tensor.extract_slice %arg30[%arg20, %arg24] [32, 32] [1, 1] : tensor<64x128xbf16> to tensor<32x32xbf16> - %17 = linalg.add ins(%14, %broadcasted : tensor<32x32xbf16>, tensor<32x32xbf16>) outs(%extracted_slice_12 : tensor<32x32xbf16>) -> tensor<32x32xbf16> - %inserted_slice_13 = tensor.insert_slice %14 into %arg29[%arg20, %arg24] [32, 32] [1, 1] : tensor<32x32xbf16> into tensor<64x128xbf16> - %extracted_slice_14 = tensor.extract_slice %arg31[%arg20, %arg24] [32, 32] [1, 1] : tensor<64x128xbf16> to tensor<32x32xbf16> - %18 = linalg.exp ins(%17 : tensor<32x32xbf16>) outs(%extracted_slice_14 : tensor<32x32xbf16>) -> tensor<32x32xbf16> - %inserted_slice_15 = tensor.insert_slice %17 into %arg30[%arg20, %arg24] [32, 32] [1, 1] : tensor<32x32xbf16> 
into tensor<64x128xbf16> - %inserted_slice_16 = tensor.insert_slice %18 into %arg31[%arg20, %arg24] [32, 32] [1, 1] : tensor<32x32xbf16> into tensor<64x128xbf16> - scf.yield %inserted_slice_13, %inserted_slice_15, %inserted_slice_16 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> - } - scf.yield %10#0, %10#1, %10#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> - } - scf.yield %9#0, %9#1, %9#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> - } - %inserted_slice = tensor.insert_slice %8#0 into %arg17[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<64x128xbf16> - %inserted_slice_3 = tensor.insert_slice %8#1 into %arg18[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<64x128xbf16> - %inserted_slice_4 = tensor.insert_slice %8#2 into %arg19[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<64x128xbf16> - scf.yield %inserted_slice, %inserted_slice_3, %inserted_slice_4 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> - } - scf.yield %7#0, %7#1, %7#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> - } - scf.yield %6#0, %6#1, %6#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> - } - scf.forall.in_parallel { - tensor.parallel_insert_slice %5#2 into %arg7[%3, %4] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<128x256xbf16> - tensor.parallel_insert_slice %5#1 into %arg6[%3, %4] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<128x256xbf16> - tensor.parallel_insert_slice %5#0 into %arg5[%3, %4] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<128x256xbf16> - } - } - return %2#2 : tensor<128x256xbf16> - } +// #map = affine_map<(d0) -> (d0 * 64)> +// #map1 = affine_map<(d0) -> (d0 * 128)> +// #map2 = affine_map<(d0) -> (d0 floordiv 16)> +// #map3 = affine_map<(d0) -> (d0 floordiv 32)> +// #map4 = affine_map<(d0, d1, d2) -> (d0 + d1 + d2 * 128)> +// #map5 = affine_map<(d0, d1, d2) -> (d0 + d1 + d2 * 64)> +// func.func @mlp(%arg0: tensor<128x512xbf16>, %arg1: tensor<32x8x16x32xbf16>, %arg2: tensor<256xbf16>) -> tensor<128x256xbf16> { +// %c32 = arith.constant 32 : index +// %c512 = arith.constant 512 : index +// %c128 = arith.constant 128 : index +// %c64 = arith.constant 64 : index +// %c0 = arith.constant 0 : index +// %cst = arith.constant 0.000000e+00 : bf16 +// %0 = tensor.empty() : tensor<128x256xbf16> +// %1 = tensor.empty() : tensor<512x256xbf16> +// %2:3 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %0, %arg6 = %0, %arg7 = %0) -> (tensor<128x256xbf16>, tensor<128x256xbf16>, tensor<128x256xbf16>) { +// %3 = affine.apply #map(%arg3) +// %4 = affine.apply #map1(%arg4) +// %extracted_slice = tensor.extract_slice %arg0[%3, 0] [64, 512] [1, 1] : tensor<128x512xbf16> to tensor<64x512xbf16> +// %extracted_slice_0 = tensor.extract_slice %arg5[%3, %4] [64, 128] [1, 1] : tensor<128x256xbf16> to tensor<64x128xbf16> +// %5:3 = scf.for %arg8 = %c0 to %c64 step %c64 iter_args(%arg9 = %extracted_slice_0, %arg10 = %extracted_slice_0, %arg11 = %extracted_slice_0) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { +// %6:3 = scf.for %arg12 = %c0 to %c128 step %c128 iter_args(%arg13 = %arg9, %arg14 = %arg10, %arg15 = %arg11) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { +// %7:3 = scf.for %arg16 = %c0 to %c512 step %c512 iter_args(%arg17 = %arg13, %arg18 = %arg14, %arg19 = %arg15) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { +// %extracted_slice_1 = tensor.extract_slice 
%extracted_slice[%arg8, %arg16] [64, 512] [1, 1] : tensor<64x512xbf16> to tensor<64x512xbf16> +// %extracted_slice_2 = tensor.extract_slice %arg17[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> to tensor<64x128xbf16> +// %8:3 = scf.for %arg20 = %c0 to %c64 step %c32 iter_args(%arg21 = %extracted_slice_2, %arg22 = %arg18, %arg23 = %arg19) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { +// %9:3 = scf.for %arg24 = %c0 to %c128 step %c32 iter_args(%arg25 = %arg21, %arg26 = %arg22, %arg27 = %arg23) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { +// %10:3 = scf.for %arg28 = %c0 to %c512 step %c512 iter_args(%arg29 = %arg25, %arg30 = %arg26, %arg31 = %arg27) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { +// %extracted_slice_5 = tensor.extract_slice %extracted_slice_1[%arg20, %arg28] [32, 512] [1, 1] : tensor<64x512xbf16> to tensor<32x512xbf16> +// %11 = affine.apply #map2(%arg28) +// %12 = affine.apply #map3(%arg24) +// %extracted_slice_6 = tensor.extract_slice %arg1[%11, %12, 0, 0] [32, 1, 16, 32] [1, 1, 1, 1] : tensor<32x8x16x32xbf16> to tensor<32x1x16x32xbf16> +// %extracted_slice_7 = tensor.extract_slice %1[%arg28, %arg24] [512, 32] [1, 1] : tensor<512x256xbf16> to tensor<512x32xbf16> +// %unpack = tensor.unpack %extracted_slice_6 inner_dims_pos = [0, 1] inner_tiles = [16, 32] into %extracted_slice_7 : tensor<32x1x16x32xbf16> -> tensor<512x32xbf16> +// %extracted_slice_8 = tensor.extract_slice %arg29[%arg20, %arg24] [32, 32] [1, 1] : tensor<64x128xbf16> to tensor<32x32xbf16> +// %13 = linalg.fill ins(%cst : bf16) outs(%extracted_slice_8 : tensor<32x32xbf16>) -> tensor<32x32xbf16> +// %expanded = tensor.expand_shape %extracted_slice_5 [[0, 1], [2]] output_shape [1, 32, 512] : tensor<32x512xbf16> into tensor<1x32x512xbf16> +// %expanded_9 = tensor.expand_shape %unpack [[0, 1], [2]] output_shape [1, 32, 512] : tensor<512x32xbf16> into tensor<1x512x32xbf16> +// %14 = linalg.batch_reduce_matmul ins(%expanded, %expanded_9 : tensor<1x32x512xbf16>, tensor<1x512x32xbf16>) outs(%13 : tensor<32x32xbf16>) -> tensor<32x32xbf16> +// %15 = affine.apply #map4(%arg12, %arg24, %arg4) +// %16 = affine.apply #map5(%arg8, %arg20, %arg3) +// %extracted_slice_10 = tensor.extract_slice %arg2[%15] [32] [1] : tensor<256xbf16> to tensor<32xbf16> +// %extracted_slice_11 = tensor.extract_slice %0[%16, %15] [32, 32] [1, 1] : tensor<128x256xbf16> to tensor<32x32xbf16> +// %broadcasted = linalg.broadcast ins(%extracted_slice_10 : tensor<32xbf16>) outs(%extracted_slice_11 : tensor<32x32xbf16>) dimensions = [0] +// %extracted_slice_12 = tensor.extract_slice %arg30[%arg20, %arg24] [32, 32] [1, 1] : tensor<64x128xbf16> to tensor<32x32xbf16> +// %17 = linalg.add ins(%14, %broadcasted : tensor<32x32xbf16>, tensor<32x32xbf16>) outs(%extracted_slice_12 : tensor<32x32xbf16>) -> tensor<32x32xbf16> +// %inserted_slice_13 = tensor.insert_slice %14 into %arg29[%arg20, %arg24] [32, 32] [1, 1] : tensor<32x32xbf16> into tensor<64x128xbf16> +// %extracted_slice_14 = tensor.extract_slice %arg31[%arg20, %arg24] [32, 32] [1, 1] : tensor<64x128xbf16> to tensor<32x32xbf16> +// %18 = linalg.exp ins(%17 : tensor<32x32xbf16>) outs(%extracted_slice_14 : tensor<32x32xbf16>) -> tensor<32x32xbf16> +// %inserted_slice_15 = tensor.insert_slice %17 into %arg30[%arg20, %arg24] [32, 32] [1, 1] : tensor<32x32xbf16> into tensor<64x128xbf16> +// %inserted_slice_16 = tensor.insert_slice %18 into %arg31[%arg20, %arg24] [32, 32] [1, 1] : tensor<32x32xbf16> into tensor<64x128xbf16> +// 
scf.yield %inserted_slice_13, %inserted_slice_15, %inserted_slice_16 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> +// } +// scf.yield %10#0, %10#1, %10#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> +// } +// scf.yield %9#0, %9#1, %9#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> +// } +// %inserted_slice = tensor.insert_slice %8#0 into %arg17[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<64x128xbf16> +// %inserted_slice_3 = tensor.insert_slice %8#1 into %arg18[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<64x128xbf16> +// %inserted_slice_4 = tensor.insert_slice %8#2 into %arg19[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<64x128xbf16> +// scf.yield %inserted_slice, %inserted_slice_3, %inserted_slice_4 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> +// } +// scf.yield %7#0, %7#1, %7#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> +// } +// scf.yield %6#0, %6#1, %6#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> +// } +// scf.forall.in_parallel { +// tensor.parallel_insert_slice %5#2 into %arg7[%3, %4] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<128x256xbf16> +// tensor.parallel_insert_slice %5#1 into %arg6[%3, %4] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<128x256xbf16> +// tensor.parallel_insert_slice %5#0 into %arg5[%3, %4] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<128x256xbf16> +// } +// } +// return %2#2 : tensor<128x256xbf16> +// } From 7a71b25cf09a8ef709743d83d0df7b0811f23b24 Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Mon, 22 Jul 2024 20:24:25 +0800 Subject: [PATCH 17/66] fix reduce bug --- include/gc/Transforms/TilingVector.h | 68 +-- lib/gc/Transforms/CPUPhysicalResigterPass.cpp | 467 +++++++++++------- .../gc/transforms/cpu-vetor-distribution.mlir | 148 +++--- 3 files changed, 406 insertions(+), 277 deletions(-) diff --git a/include/gc/Transforms/TilingVector.h b/include/gc/Transforms/TilingVector.h index a1e1b84e8..38bc51d3c 100644 --- a/include/gc/Transforms/TilingVector.h +++ b/include/gc/Transforms/TilingVector.h @@ -53,6 +53,7 @@ void setOperationCorrectOperand( const llvm::DenseMap &operandIdxMap, ArrayRef inductionVars, const llvm::DenseMap &opPermuationMap); +mlir::FailureOr getOperationOperateTensor(Operation *op); struct HardWareInfo { bool favx512f = true; @@ -378,6 +379,24 @@ class CanonicalizerCommonUsedData : public TypeHelper { // other methods bool isGroupHasSpecialOperation(const size_t grpIdx); + + void generateEmptyTensorAndWrite( + Operation *sourceOp, + llvm::DenseMap> + &srcOpCanoniclizedMap, + size_t anchorPos, ReturnTypeKind retKind); + + void updateOpOperandResultInGroups(size_t opGid, Operation *op, Value &init, + const Value &result = Value()); + + Value + canonicalizeCurrentOperation(Operation *op, const Value &transferReadOperand, + size_t operandIdx, + vector::TransferReadOp *srcReadOp = nullptr); + + Operation * + createTransferReadOpBefore(Operation *op, const Value &operand, + vector::TransferReadOp *srcReadOp = nullptr); }; class ForLoopGenerator : virtual public CanonicalizerCommonUsedData { @@ -424,7 +443,8 @@ class ForLoopGenerator : virtual public CanonicalizerCommonUsedData { llvm::DenseMap ¤tLoopStateIdxMap, llvm::DenseMap &nextAnchorArgsIdxMap, llvm::SmallVector &nextAnchorArgs, - DenseMap &originalOperandLoopArgsMap); + DenseMap &originalOperandLoopArgsMap, + DenseMap &loopArgsOriginalOperandMap); void getOperationInCurrentAnchor(const 
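  // Reviewer note (illustrative sketch, not part of the original patch): the two
  // maps added to these signatures are assumed to be maintained as inverses of
  // each other. For an scf.for that carries an original value %init as the
  // iter_arg %arg0:
  //   originalOperandLoopArgsMap : %init -> %arg0   (original value -> loop arg)
  //   loopArgsOriginalOperandMap : %arg0 -> %init   (loop arg -> original value)
  // Keeping both directions lets the loop generator look up either side when
  // operations are moved between anchor positions.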
size_t anchorIdx, std::queue &fromQueue, @@ -446,7 +466,7 @@ class ForLoopGenerator : virtual public CanonicalizerCommonUsedData { const ValueRange &loopState, const llvm::SmallVector &nextAnchorResults, DenseMap &forResultOrignalResultMap); - + void movePreOpToCurrentAnchor(const size_t anchorIdx, const size_t groupIdx, OpBuilder &b, ArrayRef inductionVars, @@ -456,12 +476,13 @@ class ForLoopGenerator : virtual public CanonicalizerCommonUsedData { llvm::SmallVector &nextAnchorArgs, std::queue &candidateQueue, std::queue &movedQueue, - DenseMap &originalOperandLoopArgsMap); + DenseMap &originalOperandLoopArgsMap, + DenseMap &LoopArgsoriginalOperandMap); void replaceOperationsWithForLoopResult( IRRewriter &rewrite, const ValueRange &forResults, const Block *forBlock, const llvm::SmallVector &nextAnchorResults, - const std::queue movingOperations, + const std::queue &movingOperations, DenseMap &forResultOrignalResultMap); // multireduction forloop methods scf::ForOp generateMultiReductionForLoop(const size_t grpIdx); @@ -469,11 +490,14 @@ class ForLoopGenerator : virtual public CanonicalizerCommonUsedData { OpBuilder &opBuilder, const int groupIdx, const size_t reductionIdx, const int anchorIdx, llvm::DenseMap ¤tLoopStateIdxMap, const ValueRange &initArgs, + DenseMap &originalOperandLoopArgsMap, + DenseMap &loopArgsOriginalOperandMap, llvm::SmallVector &nextAnchorResults, llvm::DenseMap &nextAnchorResultsIdxMap, llvm::SmallVector &inductionVars, - DenseMap &forResultOrignalResultMap); - + DenseMap &forResultOrignalResultMap, + DenseMap &originalResultForResultMap); + scf::ForOp parallelAxisGenerateForLoop( OpBuilder &opBuilder, const int groupIdx, const size_t parallelIdx, llvm::DenseMap ¤tLoopStateIdxMap, @@ -482,13 +506,9 @@ class ForLoopGenerator : virtual public CanonicalizerCommonUsedData { llvm::DenseMap &nextAnchorResultsIdxMap, llvm::SmallVector &inductionVars, DenseMap &originalOperandLoopArgsMap, + DenseMap &loopArgsOriginalOperandMap, DenseMap &forResultOrignalResultMap); - scf::ForOp - reductionAxisGenerateForLoop(OpBuilder &opBuilder, const int groupIdx, - const size_t reductionIdx, ValueRange &initArgs, - llvm::SmallVector &inductionVars); - vector::TransferReadOp cloneReductionTransferRead( Value &source, OpBuilder &b, IRMapping &readMap, const llvm::SmallVector ¶llelAxis, @@ -496,36 +516,24 @@ class ForLoopGenerator : virtual public CanonicalizerCommonUsedData { MultiReduceOpAxisKind rdKind = MultiReduceOpAxisKind::Parallel); }; -class VectorOperationAnalysizer : virtual public CanonicalizerCommonUsedData { +class VectorOperationAnalyzer : virtual public CanonicalizerCommonUsedData { private: func::FuncOp func; public: - virtual ~VectorOperationAnalysizer(){}; - VectorOperationAnalysizer() {} - VectorOperationAnalysizer(func::FuncOp &func) : func(func) {} - void generateEmptyTensorAndWrite( - Operation *sourceOp, - llvm::DenseMap> - &srcOpCanoniclizedMap, - size_t anchorPos, ReturnTypeKind retKind); + virtual ~VectorOperationAnalyzer(){}; + VectorOperationAnalyzer() {} + VectorOperationAnalyzer(func::FuncOp &func) : func(func) {} + void setAnalysisFunc(func::FuncOp &func) { this->func = func; } void analysisEmptyGroupAndMaxSteps(); void analysisGroupOperaion(); void analysisGroupOperationResults(); - Value - canonicalizeCurrentOperation(Operation *op, const Value &transferReadOperand, - size_t operandIdx, - vector::TransferReadOp *srcReadOp = nullptr); - void updateOpOperandResultInGroups(size_t opGid, Operation *op, Value &init, - const Value &result = Value()); - Operation 
* - createTransferReadOpBefore(Operation *op, const Value &operand, - vector::TransferReadOp *srcReadOp = nullptr); + void specialOperationAnchorRectify(); }; /// Vectorize vector operation with target machines simd instructions. class CanonicalizerVectorOperation : virtual public ForLoopGenerator, - VectorOperationAnalysizer { + VectorOperationAnalyzer { private: func::FuncOp func; IRRewriter rewriter; diff --git a/lib/gc/Transforms/CPUPhysicalResigterPass.cpp b/lib/gc/Transforms/CPUPhysicalResigterPass.cpp index b515b6593..1598fd7f9 100644 --- a/lib/gc/Transforms/CPUPhysicalResigterPass.cpp +++ b/lib/gc/Transforms/CPUPhysicalResigterPass.cpp @@ -153,6 +153,28 @@ bool isLastDim(const AffineExpr &expr, const size_t rank) { } } +FailureOr createArithSplatConstantOp(IRRewriter &rewriter, + const Location &loc, + const ElementsAttr &valueType, + VectorType &newOperandType) { + + if (valueType.isSplat()) { + Value res; + if (mlir::isa(valueType.getElementType())) { + res = rewriter.create( + loc, + FloatAttr::get(newOperandType, valueType.getSplatValue())); + } else { + res = rewriter.create( + loc, + IntegerAttr::get(newOperandType, valueType.getSplatValue())); + } + return res; + } + + return failure(); +} + mlir::FailureOr getOperationVectorType(Operation *op) { if (!op) { return failure(); @@ -181,10 +203,12 @@ mlir::FailureOr getOperationVectorType(Operation *op) { return multiReductionOp.getSourceVectorType(); }) .Case( - [&](arith::ConstantOp constantOp) { return failure(); }) + [&](arith::ConstantOp constantOp) -> mlir::FailureOr { + return failure(); + }) .Default([&](Operation *op) -> mlir::FailureOr { if (!op->getResults().empty()) { - auto t = mlir::dyn_cast(op->getResultTypes().front()); + auto t = dyn_cast(op->getResultTypes().front()); if (t) { if (isDynamicType(t)) { return failure(); @@ -605,7 +629,7 @@ Operation *createTransferWriteOpAfter(Operation *op, const Value &dest) { /*inBounds=*/inBoundsVal); } -Operation *VectorOperationAnalysizer::createTransferReadOpBefore( +Operation *CanonicalizerCommonUsedData::createTransferReadOpBefore( Operation *op, const Value &operand, vector::TransferReadOp *srcReadOp) { auto operandType = cast(operand.getType()); @@ -661,7 +685,7 @@ canonicalizeSourceOperation(Operation *op) { return std::make_pair(resultTensor, writeOp->getResults()[0]); } -[[nodiscard]] Value VectorOperationAnalysizer::canonicalizeCurrentOperation( +[[nodiscard]] Value CanonicalizerCommonUsedData::canonicalizeCurrentOperation( Operation *op, const Value &transferReadOperand, size_t operandIdx, vector::TransferReadOp *srcReadOp) { // transfer_read operation @@ -755,6 +779,24 @@ void classifySourceRelatedOps(std::queue &accRelatedOps, } } +/// get multi_reduction operation accumulate value source related operations +/// \param srcOp accumulate value source operation +void classifyAccRelatedOps(std::queue &accRelatedOps, + std::queue &sourceRelatedOps, + Operation *srcOp, std::queue &prevOps) { + DenseSet srcOpsSet; + getOpSourceOps(srcOp, srcOpsSet); + while (!prevOps.empty()) { + auto op = prevOps.front(); + prevOps.pop(); + if (isSrcRelated(srcOpsSet, op) or op == srcOp) { + accRelatedOps.push(op); + } else { + sourceRelatedOps.push(op); + } + } +} + void updateReduceReadWriteOperationOperand( const SmallVector &inductionVars, const SmallVector ¶llelAxis, Operation *op, @@ -877,20 +919,42 @@ void ForLoopGenerator::getResultInCurrentOps( } } } + +/// update loop args related status +/// \param nextAnchorArgsIdxMap anchor args index map +/// \param nextOperandArgsMap 
original value to next loop args map +/// \param nextArgsOperandMap next loop args to original value map +void updateCurrentArgsStatus(const ValueRange &loopState, + const size_t loopStateIdx, + SmallVector &nextAnchorArgs, + Value originalValue, + DenseMap &nextAnchorArgsIdxMap, + DenseMap &nextOperandArgsMap, + DenseMap &nextArgsOperandMap) { + Value currentArgs = loopState[loopStateIdx]; + nextAnchorArgs.emplace_back(currentArgs); + nextAnchorArgsIdxMap[currentArgs] = nextAnchorArgs.size() - 1; + nextOperandArgsMap[originalValue] = currentArgs; + nextArgsOperandMap[currentArgs] = originalValue; +} + void ForLoopGenerator::getInitArgsToNextAnchor( const size_t anchorIdx, const size_t groupId, const std::queue &nextOperations, const ValueRange &loopState, DenseMap ¤tLoopStateIdxMap, DenseMap &nextAnchorArgsIdxMap, SmallVector &nextAnchorArgs, - DenseMap &originalOperandLoopArgsMap) { + DenseMap &originalOperandLoopArgsMap, + DenseMap &loopArgsOriginalOperandMap) { DenseMap &opAnchorPos = getFusionStrategy().getOpAnchorPos(); SetVector &opInitArgs = getGroupOpInitArgs()[groupId]; + DenseSet visited; // find the next anchor arguments std::queue tmpQ(nextOperations); - DenseMap nextOperandArgsMap; + DenseMap nextOperandArgsMap, nextArgsOperandMap; + while (!tmpQ.empty()) { Operation *cur = tmpQ.front(); tmpQ.pop(); @@ -900,15 +964,15 @@ void ForLoopGenerator::getInitArgsToNextAnchor( opAnchorPos[cur] > anchorIdx) { int loopStateIdx = currentLoopStateIdxMap[originalOperandLoopArgsMap[x]]; - nextAnchorArgs.emplace_back(loopState[loopStateIdx]); - nextOperandArgsMap[x] = loopState[loopStateIdx]; - nextAnchorArgsIdxMap[loopState[loopStateIdx]] = - nextAnchorArgs.size() - 1; + updateCurrentArgsStatus(loopState, loopStateIdx, nextAnchorArgs, x, + nextAnchorArgsIdxMap, nextOperandArgsMap, + nextArgsOperandMap); visited.insert(x); } } } originalOperandLoopArgsMap = nextOperandArgsMap; + loopArgsOriginalOperandMap = nextArgsOperandMap; } void ForLoopGenerator::getOperationInCurrentAnchor( @@ -928,7 +992,7 @@ void ForLoopGenerator::getOperationInCurrentAnchor( void ForLoopGenerator::replaceOperationsWithForLoopResult( IRRewriter &rewrite, const ValueRange &forResults, const Block *forBlock, const llvm::SmallVector &nextAnchorResults, - const std::queue movingOperations, + const std::queue &movingOperations, DenseMap &forResultOrignalResultMap) { auto tmpQ(movingOperations); DenseSet operationOperands; @@ -961,7 +1025,8 @@ void ForLoopGenerator::movePreOpToCurrentAnchor( SmallVector &nextAnchorArgs, std::queue &candidateQueue, std::queue &movedQueue, - DenseMap &originalOperandLoopArgsMap) { + DenseMap &originalOperandLoopArgsMap, + DenseMap &LoopArgsoriginalOperandMap) { // 1. get operations in current anchor position std::queue movingOperation; @@ -970,7 +1035,8 @@ void ForLoopGenerator::movePreOpToCurrentAnchor( // 2. get next anchor args getInitArgsToNextAnchor(anchorIdx, groupIdx, candidateQueue, loopState, currentLoopStateIdxMap, nextAnchorArgsIdxMap, - nextAnchorArgs, originalOperandLoopArgsMap); + nextAnchorArgs, originalOperandLoopArgsMap, + LoopArgsoriginalOperandMap); // 3. 
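        // Illustrative sketch (assumption, IR shown only as an example): for a
        // reduction whose accumulator comes from its own transfer_read, the
        // classifyAccRelatedOps helper added above is expected to split the
        // previous operations roughly like this:
        //   %acc = vector.transfer_read %accTsr[%c0], %pad : tensor<16xf32>, vector<16xf32>            -> accRelatedOps
        //   %a   = vector.transfer_read %srcTsr[%c0, %c0], %pad : tensor<16x32xf32>, vector<16x32xf32>  -> sourceRelatedOps
        //   %b   = arith.addf %a, %cst : vector<16x32xf32>                                              -> sourceRelatedOps
        //   %red = vector.multi_reduction <add>, %b, %acc [1] : vector<16x32xf32> to vector<16xf32>
        // Operations on the def-use chain of the accumulator go to accRelatedOps;
        // the remaining producers go to sourceRelatedOps.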
rewrite operation as vectorize IR rewriteOperationAsVectorize(b, groupIdx, &movingOperation); @@ -1051,17 +1117,19 @@ void ForLoopGenerator::generateLoopResults( nextAnchorResultsIdxMap[result] = nextAnchorResults.size() - 1; } forResultOrignalResultMap = std::move(currentResultMap); - - maybeYieldValue(b, loc, nextAnchorResults); } scf::ForOp ForLoopGenerator::reductionAxisGenerateForLoop( OpBuilder &opBuilder, const int groupIdx, const size_t reductionIdx, - const int anchorIdx, DenseMap ¤tLoopStateIdxMap, - const ValueRange &initArgs, SmallVector &nextAnchorResults, - DenseMap &nextAnchorResultsIdxMap, - SmallVector &inductionVars, - DenseMap &forResultOrignalResultMap) { + const int anchorIdx, llvm::DenseMap ¤tLoopStateIdxMap, + const ValueRange &initArgs, + DenseMap &originalOperandLoopArgsMap, + DenseMap &loopArgsOriginalOperandMap, + llvm::SmallVector &nextAnchorResults, + llvm::DenseMap &nextAnchorResultsIdxMap, + llvm::SmallVector &inductionVars, + DenseMap &forResultOrignalResultMap, + DenseMap &originalResultForResultMap) { MultiReductionCanonicalizer rdCanonicalizer = getMultiRdCanonicalizers()[groupIdx]; @@ -1076,7 +1144,7 @@ scf::ForOp ForLoopGenerator::reductionAxisGenerateForLoop( SmallVector &reductionAxis = rdCanonicalizer.getReductionAxis(); bool lastDimReduction = rdCanonicalizer.hasLastDimReduction(); VectorType vectorType = rdCanonicalizer.getSourceType(); - const int loopStep = getDataTypeValidSteps(vectorType); + const int loopStep = getFusionStrategy().getGroupMaxSteps()[groupIdx]; IRRewriter rewriterOfFunc(func); @@ -1098,34 +1166,57 @@ scf::ForOp ForLoopGenerator::reductionAxisGenerateForLoop( DenseMap nextAnchorArgsIdxMap; SmallVector nextAnchorArgs; std::queue movedOperation; - DenseMap operandArgsMap; - movePreOpToCurrentAnchor(anchorIdx, groupIdx, b, inductionVars, - loopState, currentLoopStateIdxMap, - nextAnchorArgsIdxMap, nextAnchorArgs, - opQueue, movedOperation, operandArgsMap); + DenseMap originalArgsMap, argsOriginalMap; + movePreOpToCurrentAnchor( + anchorIdx, groupIdx, b, inductionVars, loopState, + currentLoopStateIdxMap, nextAnchorArgsIdxMap, nextAnchorArgs, + opQueue, movedOperation, originalArgsMap, argsOriginalMap); + + // replace reduction init args + if (originalOperandLoopArgsMap.contains(multireductionOp.getAcc())) { + size_t accValIdx = currentLoopStateIdxMap + [originalOperandLoopArgsMap[multireductionOp.getAcc()]]; + updateCurrentArgsStatus( + loopState, accValIdx, nextAnchorArgs, multireductionOp.getAcc(), + nextAnchorArgsIdxMap, originalArgsMap, argsOriginalMap); + } // 2. generate next for loop scf::ForOp nxtFor = reductionAxisGenerateForLoop( b, groupIdx, reductionIdx + 1, anchorIdx + 1, - nextAnchorArgsIdxMap, nextAnchorArgs, nextAnchorResults, - nextAnchorResultsIdxMap, inductionVars, - forResultOrignalResultMap); + nextAnchorArgsIdxMap, nextAnchorArgs, originalArgsMap, + argsOriginalMap, nextAnchorResults, nextAnchorResultsIdxMap, + inductionVars, forResultOrignalResultMap, + originalResultForResultMap); // 3. move postOp to current body movePostOpToCurrentAnchor( - rewriterOfFunc, anchorIdx, groupIdx, nxtFor->getResults(), - nxtFor->getBlock(), opQueue, movedOperation, inductionVars, - currentLoopStateIdxMap, loopState, nextAnchorResults, - forResultOrignalResultMap); - - // 4. 
rewrite operations as vectorized IR - rewriteOperationAsVectorize(b, groupIdx, &movedOperation); + b, anchorIdx, groupIdx, nxtFor->getResults(), b.getBlock(), + opQueue, movedOperation, inductionVars, currentLoopStateIdxMap, + loopState, nextAnchorResults, forResultOrignalResultMap); - // 5. generate loop results + // 4. generate loop results generateLoopResults(b, loc, anchorIdx, groupIdx, nextAnchorResults, nextAnchorResultsIdxMap, nxtFor->getResults(), movedOperation, forResultOrignalResultMap); + // reduction must return acc + if (originalResultForResultMap.contains( + multireductionOp->getResults()[0])) { + Value originalValue = + originalResultForResultMap[multireductionOp->getResults()[0]]; + size_t retIdx = + nextAnchorArgsIdxMap[forResultOrignalResultMap[originalValue]]; + Value forRes = nxtFor->getResults()[retIdx]; + + nextAnchorResults.emplace_back(forRes); + nextAnchorResultsIdxMap[forRes] = nextAnchorResults.size() - 1; + forResultOrignalResultMap[forRes] = originalValue; + originalResultForResultMap[originalValue] = forRes; + } + + maybeYieldValue(b, loc, nextAnchorResults); + } else if (reductionIdx == reductionAxis.size() - 1) { std::queue movingOperation; @@ -1145,20 +1236,23 @@ scf::ForOp ForLoopGenerator::reductionAxisGenerateForLoop( } break; } + rewriteOperationAsVectorize(b, groupIdx, &movingOperation); moveOperationsToCurrentForBody(groupIdx, b, inductionVars, currentLoopStateIdxMap, loopState, movingOperation); - auto reductionResult = makeArithReduction( + int accValIdx = currentLoopStateIdxMap + [originalOperandLoopArgsMap[multireductionOp.getAcc()]]; + + Value reductionResult = makeArithReduction( b, loc, multireductionOp.getKind(), multireductionOp.getSource(), - loopState.back()); + loopState[accValIdx]); movePostOpToCurrentAnchor( - rewriterOfFunc, anchorIdx, groupIdx, ValueRange(), - reductionResult.getParentBlock(), opQueue, movingOperation, - inductionVars, currentLoopStateIdxMap, loopState, + b, anchorIdx, groupIdx, ValueRange(), b.getBlock(), opQueue, + movingOperation, inductionVars, currentLoopStateIdxMap, loopState, nextAnchorResults, forResultOrignalResultMap); nextAnchorResults.clear(); @@ -1166,8 +1260,11 @@ scf::ForOp ForLoopGenerator::reductionAxisGenerateForLoop( nextAnchorResultsIdxMap[reductionResult] = 0; forResultOrignalResultMap[reductionResult] = multireductionOp->getResults()[0]; + originalResultForResultMap[multireductionOp->getResults()[0]] = + reductionResult; getResultInCurrentOps(anchorIdx, groupIdx, movingOperation, nextAnchorResults, forResultOrignalResultMap); + maybeYieldValue(b, loc, nextAnchorResults); } }); @@ -1184,6 +1281,7 @@ scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( DenseMap &nextAnchorResultsIdxMap, SmallVector &inductionVars, DenseMap &originalOperandLoopArgsMap, + DenseMap &loopArgsOriginalOperandMap, DenseMap &forResultOrignalResultMap) { auto &rdCanonicalizer = getMultiRdCanonicalizers()[groupIdx]; vector::MultiDimReductionOp &multiReductionOp = @@ -1199,7 +1297,7 @@ scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( // last dim reduction need to a generate dim=16 loop for fused with pre-op int dimSize = 0; if (parallelIdx == parallelAxis.size()) { - dimSize = getDataTypeMAXSIMDLength(vectorType); + dimSize = getFusionStrategy().getGroupMaxSteps()[groupIdx]; } else { dimSize = vectorType.getShape()[parallelAxis[parallelIdx]]; } @@ -1222,68 +1320,6 @@ scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( std::queue &opQueue = opGroups[opIndex]; Value multiReductionAcc = 
multiReductionOp.getAcc(); - // if (parallelIdx == parallelAxis.size() - 1) { - // four kinds of group operations - // If fused a operation, it means multirection must just - // constains last dim to do the reduction. - // 1. just multireduction - // two cases: - // 1. constaints last dims - // for ... parallel axis: - // transfer_read from accSource tensor - // arith.constant : vector<16xf32> - // for ... 16: - // for ... reduction axis: - // add - // scalar = reduction vector - // scalar insert into accVector - // transfer_write accVector into emtpy tensor - // 2. not last dims - // for ... generate axis: - // transfer_read from accSource tensor - // transfer_read from source tensor - // accVector = add - // transfer_write accVector into emtpy tensor - // 2. prev-op + multireduction - // In this case, there will be no tensor.empty + transfer_read - // operation, but the multireduction should write in an empty - // tensor - // for ... parallel axis: - // accVector and related accVector operation should be here - // extract from accVector scalar - // airth.constant : vector<16xf32> - // for ... reduction axis: - // prevop source op - // add - // scalar = reduction vector - // scalar insert into accVector - // transfer_write accVector into empty tensor - // - // 3. post-op + multireduction - // for ... parallel axis: - // transferread from accSource tensor - // arith.constant : vector<16xf32> - // for ... reduction axis: - // add - // postOp - // post Op transferWrite emtpy tensor - // scalar = reduction vector - // scalar insert into accVector - // transfer_write accVector into emtpy tensor - // 4. prev-op + multireduction + post-op - // for ... parallel axis: - // accVector operation - // extract from accVector a scalar - // arith.constant : vector<16xf32> - // for ... reduction axis: - // prev-op source op and related source operation - // add - // postOp - // post Op transferWrite emtpy tensor - // scalar = reduction vector - // scalar insert into accVector - // transfer_write accVector into emtpy tensor - if (parallelIdx < parallelAxis.size()) { // 1. move pre-Op to current body DenseMap nextAnchorArgsIdxMap; @@ -1292,9 +1328,17 @@ scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( movePreOpToCurrentAnchor( parallelIdx, groupIdx, b, inductionVars, loopState, currentLoopStateIdxMap, nextAnchorArgsIdxMap, nextAnchorArgs, - opQueue, movedQueue, originalOperandLoopArgsMap); + opQueue, movedQueue, originalOperandLoopArgsMap, + loopArgsOriginalOperandMap); if (parallelIdx == parallelAxis.size() - 1) { + // Ensure accumalate expression in this parallel anchor position. + // If it not appear in current anchor, we must move it in here. + // 1. delete it in operation queue + // 2. 
move it in current movedqueue + DenseMap> srcOpCanoniclizedMap; + DenseSet argsSet(nextAnchorArgs.begin(), + nextAnchorArgs.end()); std::queue checkAccQueue(movedQueue); Value accInitVal; while (!checkAccQueue.empty()) { @@ -1311,10 +1355,19 @@ scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( if (ok) break; } - assert(accInitVal && " Can't find accInit Value"); - nextAnchorArgs.emplace_back(accInitVal); - nextAnchorArgsIdxMap[accInitVal] = nextAnchorArgs.size() - 1; - originalOperandLoopArgsMap[multiReductionAcc] = accInitVal; + if (accInitVal) { + if (!argsSet.contains(accInitVal)) { + nextAnchorArgs.emplace_back(accInitVal); + nextAnchorArgsIdxMap[accInitVal] = nextAnchorArgs.size() - 1; + loopArgsOriginalOperandMap[accInitVal] = multiReductionAcc; + originalOperandLoopArgsMap[multiReductionAcc] = accInitVal; + } + + } else { + llvm::llvm_unreachable_internal( + "Wrong accumualte source value. Because " + "acc value must appear in here."); + } } // 2. generate next for loop @@ -1322,7 +1375,7 @@ scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( b, groupIdx, parallelIdx + 1, nextAnchorArgsIdxMap, nextAnchorArgs, nextAnchorResults, nextAnchorResultsIdxMap, inductionVars, originalOperandLoopArgsMap, - forResultOrignalResultMap); + loopArgsOriginalOperandMap, forResultOrignalResultMap); // 3. move postOp to current body movePostOpToCurrentAnchor( @@ -1335,35 +1388,46 @@ scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( generateLoopResults(b, loc, parallelIdx, groupIdx, nextAnchorResults, nextAnchorResultsIdxMap, nxtFor->getResults(), movedQueue, forResultOrignalResultMap); + maybeYieldValue(b, loc, nextAnchorResults); } else if (parallelIdx == parallelAxis.size()) { // get accumualte value Attribute initValueAttr; getReductionInitAttr(multiReductionOp, initValueAttr); + auto accVal = b.create( - loc, DenseElementsAttr::get(getVectorzedType(multiReductionOp), - {initValueAttr})); + loc, DenseElementsAttr::get( + getVectorzedType(multiReductionOp, dimSize), + {initValueAttr})); - DenseMap localAnchorResultsIdxMap; + DenseMap localAnchorArgsIdxMap; + DenseMap localOriginalOperandLoopArgsMap, + localLoopArgsOriginalOperandMap; SmallVector argsArray; argsArray.emplace_back(accVal); - localAnchorResultsIdxMap[accVal] = 0; + localAnchorArgsIdxMap[accVal] = 0; size_t accLoopStateIdx = currentLoopStateIdxMap [originalOperandLoopArgsMap[multiReductionAcc]]; + localLoopArgsOriginalOperandMap[accVal] = multiReductionAcc; + localOriginalOperandLoopArgsMap[multiReductionAcc] = accVal; for (auto [idx, x] : llvm::enumerate(loopState)) { if (idx == accLoopStateIdx) { continue; } argsArray.emplace_back(x); - localAnchorResultsIdxMap[accVal] = argsArray.size() - 1; + localAnchorArgsIdxMap[x] = argsArray.size() - 1; + Value originalValue = loopArgsOriginalOperandMap[initArgs[idx]]; + localOriginalOperandLoopArgsMap[originalValue] = x; + localLoopArgsOriginalOperandMap[x] = originalValue; } - ValueRange reductionInitArgs(argsArray); + DenseMap originalResultForResultMap; auto nxtFor = reductionAxisGenerateForLoop( - b, groupIdx, 0, parallelIdx, localAnchorResultsIdxMap, - reductionInitArgs, nextAnchorResults, nextAnchorResultsIdxMap, - inductionVars, forResultOrignalResultMap); + b, groupIdx, 0, parallelIdx, localAnchorArgsIdxMap, argsArray, + localOriginalOperandLoopArgsMap, localLoopArgsOriginalOperandMap, + nextAnchorResults, nextAnchorResultsIdxMap, inductionVars, + forResultOrignalResultMap, originalResultForResultMap); // insert accumulate value to original vector auto accRes = 
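      // Illustrative sketch (assumption, simplified; the extra 16-lane loop that
      // packs scalars back into the accumulator vector is omitted): for a
      // last-dim reduction of a 16x32 f32 source, the nest built here is
      // expected to look roughly like
      //   scf.for %i ... {                                            // parallel axis
      //     %zero = arith.constant dense<0.000000e+00> : vector<16xf32>
      //     %acc = scf.for %j = %c0 to %c32 step %c16
      //              iter_args(%v = %zero) -> (vector<16xf32>) {      // reduction axis
      //       %ld  = vector.transfer_read %src[%i, %j], %pad : tensor<16x32xf32>, vector<16xf32>
      //       %add = arith.addf %v, %ld : vector<16xf32>
      //       scf.yield %add : vector<16xf32>
      //     }
      //     %s = vector.reduction <add>, %acc : vector<16xf32> into f32
      //     // %s is then inserted into the carried accumulator vector with
      //     // vector.insertelement and written out after the loop nest
      //   }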
nxtFor->getResults()[0]; @@ -1371,7 +1435,7 @@ scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( Operation *reductionOp = b.create( loc, multiReductionOp.getKind(), accRes); auto insertOp = b.create( - loc, reductionOp->getResult(0), accVal, iv); + loc, reductionOp->getResult(0), loopState[accLoopStateIdx], iv); // generate loop result SmallVector currentAnchorResults; @@ -1408,6 +1472,38 @@ scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( scf::ForOp ForLoopGenerator::generateMultiReductionForLoop(const size_t grpIdx) { auto &rdCanonicalizer = getMultiRdCanonicalizers()[grpIdx]; + auto multiReductionOp = rdCanonicalizer.getCandidateOps()[0]; + std::queue &prevOps = rdCanonicalizer.getPrevOps(); + std::queue &postOps = rdCanonicalizer.getPostOps(); + std::queue &accRelatedOps = rdCanonicalizer.getAccRelatedOps(); + std::queue &sourceRelatedOps = + rdCanonicalizer.getSourceRelatedOps(); + + std::queue &opQueue = getFusionStrategy().getOpGroups()[grpIdx]; + auto copyOpQueue(opQueue); + getPrevOps(prevOps, copyOpQueue, multiReductionOp); + getPostOps(postOps, copyOpQueue, multiReductionOp); + classifyAccRelatedOps(accRelatedOps, sourceRelatedOps, + multiReductionOp.getAcc().getDefiningOp(), prevOps); + // move acc related operation to operation first + std::queue rectifyQueue; + DenseSet pushedSet; + auto moveOperation = [&](std::queue &from, + std::queue &to) { + while (!from.empty()) { + auto cur = from.front(); + from.pop(); + if (pushedSet.contains(cur)) { + continue; + } + to.push(cur); + pushedSet.insert(cur); + } + }; + moveOperation(accRelatedOps, rectifyQueue); + moveOperation(opQueue, rectifyQueue); + opQueue = rectifyQueue; + // get current loop init args SetVector &grpArgs = getGroupOpInitArgs()[grpIdx]; SmallVector forLoopArgs(grpArgs.begin(), grpArgs.end()); @@ -1415,7 +1511,8 @@ ForLoopGenerator::generateMultiReductionForLoop(const size_t grpIdx) { DenseMap currentLoopStateIdxMap; DenseMap nextAnchorResultsIdxMap; // map original operation operand with loop args - DenseMap originalOperandLoopArgsMap, forResultOrignalResultMap; + DenseMap originalOperandLoopArgsMap, loopArgsOriginalOperandMap, + forResultOrignalResultMap; for (auto [idx, val] : llvm::enumerate(initArgs)) { currentLoopStateIdxMap[val] = idx; originalOperandLoopArgsMap[val] = val; @@ -1429,7 +1526,7 @@ ForLoopGenerator::generateMultiReductionForLoop(const size_t grpIdx) { scf::ForOp forOp = parallelAxisGenerateForLoop( opBuilder, grpIdx, 0, currentLoopStateIdxMap, initArgs, nextAnchorResults, nextAnchorResultsIdxMap, inductionVars, originalOperandLoopArgsMap, - forResultOrignalResultMap); + loopArgsOriginalOperandMap, forResultOrignalResultMap); auto replaceIfFn = [&](OpOperand &use) { return use.getOwner()->getBlock() == forOp->getBlock(); @@ -1442,7 +1539,7 @@ ForLoopGenerator::generateMultiReductionForLoop(const size_t grpIdx) { } rewriter.replaceOp(getMultiRdCanonicalizers()[grpIdx].getCandidateOps()[0], forOp); - + forOp->dump(); return forOp; } @@ -1534,19 +1631,19 @@ void CanonicalizerVectorOperation::initSpeicalOperationCanonicalizers() { while (!tempQ.empty()) { auto op = tempQ.front(); tempQ.pop(); - if (mlir::isa(op)) { + if (isa(op)) { getMultiRdCanonicalizers().back().getCandidateOps().emplace_back( - mlir::dyn_cast(op)); + cast(op)); getMultiRdCanonicalizers().back().prepareSpecialOperationInfo(); - } else if (mlir::isa(op)) { + } else if (isa(op)) { getBroadcastCanonicalizers().back().getCandidateOps().emplace_back( - mlir::dyn_cast(op)); - } else if (mlir::isa(op)) { + cast(op)); + 
} else if (isa(op)) { getTransposeCanonicalizers().back().getCandidateOps().emplace_back( - mlir::dyn_cast(op)); - } else if (mlir::isa(op)) { + cast(op)); + } else if (isa(op)) { getShapeCastCanonicalizers().back().getCandidateOps().emplace_back( - mlir::dyn_cast(op)); + cast(op)); } } } @@ -1619,6 +1716,7 @@ void CanonicalizerVectorOperation::run() { // Query groupResultYeildSet to map operaion result value to scf.yield // result value. analysisGroupOperaion(); + // printGroupOps(fusionStrategy.getOpGroups()); // Speical Operation Canonicalization canonicalizeSpecialOperation(); @@ -2152,37 +2250,32 @@ void ForLoopGenerator::createNewConstantOp( auto &opPermuationMap = getOpPermuationMap(); IRRewriter srcWriter(srcOp); auto newOperandType = getVectorzedType(mlir::cast(srcOp)); - auto srcConstantOp = mlir::dyn_cast(srcOp); + auto srcConstantOp = dyn_cast(srcOp); Operation *newConstantOp; if (mlir::isa(srcConstantOp.getValue())) { auto valueType = mlir::dyn_cast(srcConstantOp.getValue()); if (valueType.isSplat()) { - if (mlir::isa(valueType.getElementType())) { - newConstantOp = srcWriter.create( - srcOp->getLoc(), - FloatAttr::get(newOperandType, valueType.getSplatValue())); - } else { - newConstantOp = srcWriter.create( - srcOp->getLoc(), - IntegerAttr::get(newOperandType, valueType.getSplatValue())); + FailureOr res = createArithSplatConstantOp( + srcWriter, srcOp->getLoc(), valueType, newOperandType); + if (failed(res)) { + llvm::llvm_unreachable_internal("Wrong to create constant op."); } + newConstantOp = res.value().getDefiningOp(); } else { - assert(0 && "Not support non-splat constant value."); + newConstantOp = srcWriter.create( + srcOp->getLoc(), srcConstantOp.getValue()); } - } else { - newConstantOp = srcWriter.create( - srcOp->getLoc(), srcConstantOp.getValue()); - } - newConstantOp->getResult(0).setType(newOperandType); - transferWriteOp->setOperand(0, newConstantOp->getResult(0)); - opPermuationMap.insert( - {mlir::cast(srcOp), transferWriteOp->getPermutationMap()}); - setOpVectorizationPermutationMap( - mlir::cast(srcOp), srcWriter, - mlir::dyn_cast( - transferWriteOp->getResults()[0].getType()), - transferWriteOp->getPermutationMap()); + newConstantOp->getResult(0).setType(newOperandType); + transferWriteOp->setOperand(0, newConstantOp->getResult(0)); + opPermuationMap.insert( + {mlir::cast(srcOp), transferWriteOp->getPermutationMap()}); + setOpVectorizationPermutationMap( + mlir::cast(srcOp), srcWriter, + mlir::dyn_cast( + transferWriteOp->getResults()[0].getType()), + transferWriteOp->getPermutationMap()); + } } /// Rewrite the operations in the group to vectorized form. 
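// Illustrative sketch (assumption about concrete types, shown as an example
// only): with the splat helper used above, a splat constant feeding a
// transfer_write is expected to be re-materialized at the physical-register
// vector type, e.g. for f32 with 16 lanes:
//   before:  %cst = arith.constant dense<0.000000e+00> : vector<16x32xf32>
//   after :  %cst = arith.constant dense<0.000000e+00> : vector<16xf32>
//            %w   = vector.transfer_write %cst, %dest[%i, %j] {in_bounds = [true]}
//                     : vector<16xf32>, tensor<16x32xf32>
// Non-splat dense constants are now cloned with their original value instead
// of being rejected by an assertion.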
@@ -2282,7 +2375,7 @@ mlir::FailureOr getOperationOperateTensor(Operation *op) { }); } -void VectorOperationAnalysizer::updateOpOperandResultInGroups( +void CanonicalizerCommonUsedData::updateOpOperandResultInGroups( size_t opGid, Operation *op, Value &init, const Value &result) { std::queue tmpOpQueue(getFusionStrategy().getOpGroups()[opGid]); std::queue newOpQueue; @@ -2315,7 +2408,7 @@ void VectorOperationAnalysizer::updateOpOperandResultInGroups( void VectorFusionStrategy::run() { classifyOperations(); } -void VectorOperationAnalysizer::generateEmptyTensorAndWrite( +void CanonicalizerCommonUsedData::generateEmptyTensorAndWrite( Operation *sourceOp, DenseMap> &srcOpCanoniclizedMap, size_t anchorPos, ReturnTypeKind retKind) { @@ -2326,19 +2419,19 @@ void VectorOperationAnalysizer::generateEmptyTensorAndWrite( &groupOpResults = getGroupOpResults(); size_t sourceOpGid = opGroupIndexMap[sourceOp]; - auto [resultTensor, result] = canonicalizeSourceOperation(sourceOp); - auto writeOp = result.getDefiningOp(); - srcOpCanoniclizedMap.insert({sourceOp, {resultTensor, result}}); - updateOpOperandResultInGroups(sourceOpGid, sourceOp, resultTensor, result); - groupOpInitArgs[sourceOpGid].insert(resultTensor); - groupOpResults[sourceOpGid].insert({result, {retKind, anchorPos}}); - groupOpResults[sourceOpGid].back().first.dump(); + auto [tsr, writeOpresult] = canonicalizeSourceOperation(sourceOp); + auto writeOp = writeOpresult.getDefiningOp(); + srcOpCanoniclizedMap.insert({sourceOp, {tsr, writeOpresult}}); + updateOpOperandResultInGroups(sourceOpGid, sourceOp, tsr, writeOpresult); + groupOpInitArgs[sourceOpGid].insert(tsr); + groupOpResults[sourceOpGid].insert({writeOpresult, {retKind, anchorPos}}); + // write opeartion anchor pos is same with current operation getFusionStrategy().getOpAnchorPos()[writeOp] = cast(writeOp).getVectorType().getRank() - 1; getOpPermuationMap()[writeOp] = writeOp.getPermutationMap(); } -void VectorOperationAnalysizer::analysisEmptyGroupAndMaxSteps() { +void VectorOperationAnalyzer::analysisEmptyGroupAndMaxSteps() { auto &groupOpResults = getGroupOpResults(); auto &opGroups = getFusionStrategy().getOpGroups(); @@ -2372,9 +2465,25 @@ void VectorOperationAnalysizer::analysisEmptyGroupAndMaxSteps() { } } +void VectorOperationAnalyzer::specialOperationAnchorRectify() { + auto &opGroups = getFusionStrategy().getOpGroups(); + for (auto [idx, grp] : llvm::enumerate(opGroups)) { + std::queue tmpQueue(grp); + while (!tmpQueue.empty()) { + auto op = tmpQueue.front(); + tmpQueue.pop(); + if (isa(op)) { + auto accSourceOp = op->getOperand(1).getDefiningOp(); + getFusionStrategy().getOpAnchorPos()[accSourceOp] = + getOperationVectorType(accSourceOp)->getRank() - 1; + } + } + } +} + // analysis operation result of current group whether needed by other // operation which out of current group -void VectorOperationAnalysizer::analysisGroupOperationResults() { +void VectorOperationAnalyzer::analysisGroupOperationResults() { DenseMap> srcOpCanoniclizedMap; DenseSet movedOperationSet; DenseMap &opGroupIndexMap = @@ -2399,25 +2508,23 @@ void VectorOperationAnalysizer::analysisGroupOperationResults() { if (prevRtKind != rtKind) { groupOpResults[sourceOpGid][sourceResult] = std::make_pair(ReturnTypeKind::RT_Both, srcOpAnchor); - } else { + } else if (rtKind == ReturnTypeKind::RT_InGroup) { groupOpResults[sourceOpGid][sourceResult] = std::make_pair(rtKind, srcOpAnchor); } - groupOpResults[sourceOpGid].back().first.dump(); }; func.walk([&](Operation *op) { for (auto [idx, opd] : 
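// Illustrative sketch (assumption): specialOperationAnchorRectify above pins
// the accumulator producer of a multi_reduction to the innermost anchor, e.g.
//   %acc = vector.transfer_read %t[%c0], %pad : tensor<16xf32>, vector<16xf32>
//   %red = vector.multi_reduction <add>, %src, %acc [1]
//            : vector<16x32xf32> to vector<16xf32>
// sets the anchor position of the transfer_read producing %acc to its vector
// rank minus one, so it is materialized next to the reduction.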
llvm::enumerate(op->getOperands())) { Operation *sourceOp = opd.getDefiningOp(); - if (opGroupIndexMap.contains(sourceOp)) { auto sourceOpGid = opGroupIndexMap[sourceOp]; bool notInSameGroup = opGroupIndexMap.contains(op) && sourceOpGid != opGroupIndexMap[op]; bool outOfGroup = !opGroupIndexMap.contains(op); - // Different anchor in same group and source operation is in inner loop, - // we need to get source operation's result + // Different anchor in same group and source operation is in inner + // loop, we need to get source operation's result bool inSameGroupNeedReturn = !notInSameGroup and OpAnchorPos[sourceOp] > OpAnchorPos[op]; ReturnTypeKind rtKind = inSameGroupNeedReturn @@ -2463,6 +2570,18 @@ void VectorOperationAnalysizer::analysisGroupOperationResults() { } } } + } else if (isa_and_nonnull(sourceOp)) { + auto constantOp = cast(sourceOp); + IRRewriter rewriter(constantOp); + if (mlir::isa(constantOp.getValue())) { + if (!srcOpCanoniclizedMap.contains(sourceOp)) { + auto [tsr, writeOpresult] = canonicalizeSourceOperation(sourceOp); + srcOpCanoniclizedMap.insert({sourceOp, {tsr, writeOpresult}}); + } + auto opInit = canonicalizeCurrentOperation( + op, srcOpCanoniclizedMap[sourceOp].second, idx); + updateOpOperandResultInGroups(opGroupIndexMap[op], op, opInit); + } } } if (mlir::isa(op) && !movedOperationSet.contains(op)) { @@ -2472,11 +2591,12 @@ void VectorOperationAnalysizer::analysisGroupOperationResults() { } }); analysisEmptyGroupAndMaxSteps(); + specialOperationAnchorRectify(); #undef RESULT_RETURN_TYPE LDBG("Complete analysis group operation results\n"); } -void VectorOperationAnalysizer::analysisGroupOperaion() { +void VectorOperationAnalyzer::analysisGroupOperaion() { // Results analysisGroupOperationResults(); } @@ -2578,7 +2698,8 @@ struct CPUPhysicalRegisterPass LDBG("Not support operation appears in current function."); return; } - // canonicalize vector operation, default use vector-based fusion strategy. + // canonicalize vector operation, default use vector-based fusion + // strategy. 
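    // Illustrative sketch (assumption, example IR only): the arith.constant
    // handling added above in analysisGroupOperationResults is expected to
    // bridge a dense vector constant into its consumer group through a tensor,
    // roughly:
    //   %cst = arith.constant dense<...> : vector<16xf32>
    //   %e   = tensor.empty() : tensor<16xf32>
    //   %w   = vector.transfer_write %cst, %e[%c0] {in_bounds = [true]}
    //            : vector<16xf32>, tensor<16xf32>
    //   %r   = vector.transfer_read %w[%c0], %pad {in_bounds = [true]}
    //            : tensor<16xf32>, vector<16xf32>
    // so the consumer reads %r instead of using %cst across the group boundary.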
HardWareInfo hwInfo; // default has avx512f instructions // hwInfo.favx512f = false; diff --git a/test/gc/transforms/cpu-vetor-distribution.mlir b/test/gc/transforms/cpu-vetor-distribution.mlir index 6e0561ac3..60d572d2d 100644 --- a/test/gc/transforms/cpu-vetor-distribution.mlir +++ b/test/gc/transforms/cpu-vetor-distribution.mlir @@ -141,77 +141,77 @@ func.func @matmul_add(%arg0: tensor<8192x12288xf16>, %arg1: tensor<12288x16384xf return %3 : tensor<8192x16384xf32> } -// #map = affine_map<(d0) -> (d0 * 64)> -// #map1 = affine_map<(d0) -> (d0 * 128)> -// #map2 = affine_map<(d0) -> (d0 floordiv 16)> -// #map3 = affine_map<(d0) -> (d0 floordiv 32)> -// #map4 = affine_map<(d0, d1, d2) -> (d0 + d1 + d2 * 128)> -// #map5 = affine_map<(d0, d1, d2) -> (d0 + d1 + d2 * 64)> -// func.func @mlp(%arg0: tensor<128x512xbf16>, %arg1: tensor<32x8x16x32xbf16>, %arg2: tensor<256xbf16>) -> tensor<128x256xbf16> { -// %c32 = arith.constant 32 : index -// %c512 = arith.constant 512 : index -// %c128 = arith.constant 128 : index -// %c64 = arith.constant 64 : index -// %c0 = arith.constant 0 : index -// %cst = arith.constant 0.000000e+00 : bf16 -// %0 = tensor.empty() : tensor<128x256xbf16> -// %1 = tensor.empty() : tensor<512x256xbf16> -// %2:3 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %0, %arg6 = %0, %arg7 = %0) -> (tensor<128x256xbf16>, tensor<128x256xbf16>, tensor<128x256xbf16>) { -// %3 = affine.apply #map(%arg3) -// %4 = affine.apply #map1(%arg4) -// %extracted_slice = tensor.extract_slice %arg0[%3, 0] [64, 512] [1, 1] : tensor<128x512xbf16> to tensor<64x512xbf16> -// %extracted_slice_0 = tensor.extract_slice %arg5[%3, %4] [64, 128] [1, 1] : tensor<128x256xbf16> to tensor<64x128xbf16> -// %5:3 = scf.for %arg8 = %c0 to %c64 step %c64 iter_args(%arg9 = %extracted_slice_0, %arg10 = %extracted_slice_0, %arg11 = %extracted_slice_0) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { -// %6:3 = scf.for %arg12 = %c0 to %c128 step %c128 iter_args(%arg13 = %arg9, %arg14 = %arg10, %arg15 = %arg11) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { -// %7:3 = scf.for %arg16 = %c0 to %c512 step %c512 iter_args(%arg17 = %arg13, %arg18 = %arg14, %arg19 = %arg15) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { -// %extracted_slice_1 = tensor.extract_slice %extracted_slice[%arg8, %arg16] [64, 512] [1, 1] : tensor<64x512xbf16> to tensor<64x512xbf16> -// %extracted_slice_2 = tensor.extract_slice %arg17[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> to tensor<64x128xbf16> -// %8:3 = scf.for %arg20 = %c0 to %c64 step %c32 iter_args(%arg21 = %extracted_slice_2, %arg22 = %arg18, %arg23 = %arg19) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { -// %9:3 = scf.for %arg24 = %c0 to %c128 step %c32 iter_args(%arg25 = %arg21, %arg26 = %arg22, %arg27 = %arg23) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { -// %10:3 = scf.for %arg28 = %c0 to %c512 step %c512 iter_args(%arg29 = %arg25, %arg30 = %arg26, %arg31 = %arg27) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { -// %extracted_slice_5 = tensor.extract_slice %extracted_slice_1[%arg20, %arg28] [32, 512] [1, 1] : tensor<64x512xbf16> to tensor<32x512xbf16> -// %11 = affine.apply #map2(%arg28) -// %12 = affine.apply #map3(%arg24) -// %extracted_slice_6 = tensor.extract_slice %arg1[%11, %12, 0, 0] [32, 1, 16, 32] [1, 1, 1, 1] : tensor<32x8x16x32xbf16> to tensor<32x1x16x32xbf16> -// %extracted_slice_7 = tensor.extract_slice %1[%arg28, %arg24] [512, 
32] [1, 1] : tensor<512x256xbf16> to tensor<512x32xbf16> -// %unpack = tensor.unpack %extracted_slice_6 inner_dims_pos = [0, 1] inner_tiles = [16, 32] into %extracted_slice_7 : tensor<32x1x16x32xbf16> -> tensor<512x32xbf16> -// %extracted_slice_8 = tensor.extract_slice %arg29[%arg20, %arg24] [32, 32] [1, 1] : tensor<64x128xbf16> to tensor<32x32xbf16> -// %13 = linalg.fill ins(%cst : bf16) outs(%extracted_slice_8 : tensor<32x32xbf16>) -> tensor<32x32xbf16> -// %expanded = tensor.expand_shape %extracted_slice_5 [[0, 1], [2]] output_shape [1, 32, 512] : tensor<32x512xbf16> into tensor<1x32x512xbf16> -// %expanded_9 = tensor.expand_shape %unpack [[0, 1], [2]] output_shape [1, 32, 512] : tensor<512x32xbf16> into tensor<1x512x32xbf16> -// %14 = linalg.batch_reduce_matmul ins(%expanded, %expanded_9 : tensor<1x32x512xbf16>, tensor<1x512x32xbf16>) outs(%13 : tensor<32x32xbf16>) -> tensor<32x32xbf16> -// %15 = affine.apply #map4(%arg12, %arg24, %arg4) -// %16 = affine.apply #map5(%arg8, %arg20, %arg3) -// %extracted_slice_10 = tensor.extract_slice %arg2[%15] [32] [1] : tensor<256xbf16> to tensor<32xbf16> -// %extracted_slice_11 = tensor.extract_slice %0[%16, %15] [32, 32] [1, 1] : tensor<128x256xbf16> to tensor<32x32xbf16> -// %broadcasted = linalg.broadcast ins(%extracted_slice_10 : tensor<32xbf16>) outs(%extracted_slice_11 : tensor<32x32xbf16>) dimensions = [0] -// %extracted_slice_12 = tensor.extract_slice %arg30[%arg20, %arg24] [32, 32] [1, 1] : tensor<64x128xbf16> to tensor<32x32xbf16> -// %17 = linalg.add ins(%14, %broadcasted : tensor<32x32xbf16>, tensor<32x32xbf16>) outs(%extracted_slice_12 : tensor<32x32xbf16>) -> tensor<32x32xbf16> -// %inserted_slice_13 = tensor.insert_slice %14 into %arg29[%arg20, %arg24] [32, 32] [1, 1] : tensor<32x32xbf16> into tensor<64x128xbf16> -// %extracted_slice_14 = tensor.extract_slice %arg31[%arg20, %arg24] [32, 32] [1, 1] : tensor<64x128xbf16> to tensor<32x32xbf16> -// %18 = linalg.exp ins(%17 : tensor<32x32xbf16>) outs(%extracted_slice_14 : tensor<32x32xbf16>) -> tensor<32x32xbf16> -// %inserted_slice_15 = tensor.insert_slice %17 into %arg30[%arg20, %arg24] [32, 32] [1, 1] : tensor<32x32xbf16> into tensor<64x128xbf16> -// %inserted_slice_16 = tensor.insert_slice %18 into %arg31[%arg20, %arg24] [32, 32] [1, 1] : tensor<32x32xbf16> into tensor<64x128xbf16> -// scf.yield %inserted_slice_13, %inserted_slice_15, %inserted_slice_16 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> -// } -// scf.yield %10#0, %10#1, %10#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> -// } -// scf.yield %9#0, %9#1, %9#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> -// } -// %inserted_slice = tensor.insert_slice %8#0 into %arg17[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<64x128xbf16> -// %inserted_slice_3 = tensor.insert_slice %8#1 into %arg18[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<64x128xbf16> -// %inserted_slice_4 = tensor.insert_slice %8#2 into %arg19[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<64x128xbf16> -// scf.yield %inserted_slice, %inserted_slice_3, %inserted_slice_4 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> -// } -// scf.yield %7#0, %7#1, %7#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> -// } -// scf.yield %6#0, %6#1, %6#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> -// } -// scf.forall.in_parallel { -// tensor.parallel_insert_slice %5#2 into %arg7[%3, %4] [64, 128] [1, 1] : 
tensor<64x128xbf16> into tensor<128x256xbf16> -// tensor.parallel_insert_slice %5#1 into %arg6[%3, %4] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<128x256xbf16> -// tensor.parallel_insert_slice %5#0 into %arg5[%3, %4] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<128x256xbf16> -// } -// } -// return %2#2 : tensor<128x256xbf16> -// } +#map = affine_map<(d0) -> (d0 * 64)> +#map1 = affine_map<(d0) -> (d0 * 128)> +#map2 = affine_map<(d0) -> (d0 floordiv 16)> +#map3 = affine_map<(d0) -> (d0 floordiv 32)> +#map4 = affine_map<(d0, d1, d2) -> (d0 + d1 + d2 * 128)> +#map5 = affine_map<(d0, d1, d2) -> (d0 + d1 + d2 * 64)> + func.func @mlp(%arg0: tensor<128x512xbf16>, %arg1: tensor<32x8x16x32xbf16>, %arg2: tensor<256xbf16>) -> tensor<128x256xbf16> { + %c32 = arith.constant 32 : index + %c512 = arith.constant 512 : index + %c128 = arith.constant 128 : index + %c64 = arith.constant 64 : index + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : bf16 + %0 = tensor.empty() : tensor<128x256xbf16> + %1 = tensor.empty() : tensor<512x256xbf16> + %2:3 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %0, %arg6 = %0, %arg7 = %0) -> (tensor<128x256xbf16>, tensor<128x256xbf16>, tensor<128x256xbf16>) { + %3 = affine.apply #map(%arg3) + %4 = affine.apply #map1(%arg4) + %extracted_slice = tensor.extract_slice %arg0[%3, 0] [64, 512] [1, 1] : tensor<128x512xbf16> to tensor<64x512xbf16> + %extracted_slice_0 = tensor.extract_slice %arg5[%3, %4] [64, 128] [1, 1] : tensor<128x256xbf16> to tensor<64x128xbf16> + %5:3 = scf.for %arg8 = %c0 to %c64 step %c64 iter_args(%arg9 = %extracted_slice_0, %arg10 = %extracted_slice_0, %arg11 = %extracted_slice_0) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { + %6:3 = scf.for %arg12 = %c0 to %c128 step %c128 iter_args(%arg13 = %arg9, %arg14 = %arg10, %arg15 = %arg11) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { + %7:3 = scf.for %arg16 = %c0 to %c512 step %c512 iter_args(%arg17 = %arg13, %arg18 = %arg14, %arg19 = %arg15) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { + %extracted_slice_1 = tensor.extract_slice %extracted_slice[%arg8, %arg16] [64, 512] [1, 1] : tensor<64x512xbf16> to tensor<64x512xbf16> + %extracted_slice_2 = tensor.extract_slice %arg17[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> to tensor<64x128xbf16> + %8:3 = scf.for %arg20 = %c0 to %c64 step %c32 iter_args(%arg21 = %extracted_slice_2, %arg22 = %arg18, %arg23 = %arg19) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { + %9:3 = scf.for %arg24 = %c0 to %c128 step %c32 iter_args(%arg25 = %arg21, %arg26 = %arg22, %arg27 = %arg23) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { + %10:3 = scf.for %arg28 = %c0 to %c512 step %c512 iter_args(%arg29 = %arg25, %arg30 = %arg26, %arg31 = %arg27) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { + %extracted_slice_5 = tensor.extract_slice %extracted_slice_1[%arg20, %arg28] [32, 512] [1, 1] : tensor<64x512xbf16> to tensor<32x512xbf16> + %11 = affine.apply #map2(%arg28) + %12 = affine.apply #map3(%arg24) + %extracted_slice_6 = tensor.extract_slice %arg1[%11, %12, 0, 0] [32, 1, 16, 32] [1, 1, 1, 1] : tensor<32x8x16x32xbf16> to tensor<32x1x16x32xbf16> + %extracted_slice_7 = tensor.extract_slice %1[%arg28, %arg24] [512, 32] [1, 1] : tensor<512x256xbf16> to tensor<512x32xbf16> + %unpack = tensor.unpack %extracted_slice_6 inner_dims_pos = [0, 1] inner_tiles = [16, 32] into %extracted_slice_7 : tensor<32x1x16x32xbf16> 
-> tensor<512x32xbf16> + %extracted_slice_8 = tensor.extract_slice %arg29[%arg20, %arg24] [32, 32] [1, 1] : tensor<64x128xbf16> to tensor<32x32xbf16> + %13 = linalg.fill ins(%cst : bf16) outs(%extracted_slice_8 : tensor<32x32xbf16>) -> tensor<32x32xbf16> + %expanded = tensor.expand_shape %extracted_slice_5 [[0, 1], [2]] output_shape [1, 32, 512] : tensor<32x512xbf16> into tensor<1x32x512xbf16> + %expanded_9 = tensor.expand_shape %unpack [[0, 1], [2]] output_shape [1, 32, 512] : tensor<512x32xbf16> into tensor<1x512x32xbf16> + %14 = linalg.batch_reduce_matmul ins(%expanded, %expanded_9 : tensor<1x32x512xbf16>, tensor<1x512x32xbf16>) outs(%13 : tensor<32x32xbf16>) -> tensor<32x32xbf16> + %15 = affine.apply #map4(%arg12, %arg24, %arg4) + %16 = affine.apply #map5(%arg8, %arg20, %arg3) + %extracted_slice_10 = tensor.extract_slice %arg2[%15] [32] [1] : tensor<256xbf16> to tensor<32xbf16> + %extracted_slice_11 = tensor.extract_slice %0[%16, %15] [32, 32] [1, 1] : tensor<128x256xbf16> to tensor<32x32xbf16> + %broadcasted = linalg.broadcast ins(%extracted_slice_10 : tensor<32xbf16>) outs(%extracted_slice_11 : tensor<32x32xbf16>) dimensions = [0] + %extracted_slice_12 = tensor.extract_slice %arg30[%arg20, %arg24] [32, 32] [1, 1] : tensor<64x128xbf16> to tensor<32x32xbf16> + %17 = linalg.add ins(%14, %broadcasted : tensor<32x32xbf16>, tensor<32x32xbf16>) outs(%extracted_slice_12 : tensor<32x32xbf16>) -> tensor<32x32xbf16> + %inserted_slice_13 = tensor.insert_slice %14 into %arg29[%arg20, %arg24] [32, 32] [1, 1] : tensor<32x32xbf16> into tensor<64x128xbf16> + %extracted_slice_14 = tensor.extract_slice %arg31[%arg20, %arg24] [32, 32] [1, 1] : tensor<64x128xbf16> to tensor<32x32xbf16> + %18 = linalg.exp ins(%17 : tensor<32x32xbf16>) outs(%extracted_slice_14 : tensor<32x32xbf16>) -> tensor<32x32xbf16> + %inserted_slice_15 = tensor.insert_slice %17 into %arg30[%arg20, %arg24] [32, 32] [1, 1] : tensor<32x32xbf16> into tensor<64x128xbf16> + %inserted_slice_16 = tensor.insert_slice %18 into %arg31[%arg20, %arg24] [32, 32] [1, 1] : tensor<32x32xbf16> into tensor<64x128xbf16> + scf.yield %inserted_slice_13, %inserted_slice_15, %inserted_slice_16 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> + } + scf.yield %10#0, %10#1, %10#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> + } + scf.yield %9#0, %9#1, %9#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> + } + %inserted_slice = tensor.insert_slice %8#0 into %arg17[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<64x128xbf16> + %inserted_slice_3 = tensor.insert_slice %8#1 into %arg18[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<64x128xbf16> + %inserted_slice_4 = tensor.insert_slice %8#2 into %arg19[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<64x128xbf16> + scf.yield %inserted_slice, %inserted_slice_3, %inserted_slice_4 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> + } + scf.yield %7#0, %7#1, %7#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> + } + scf.yield %6#0, %6#1, %6#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> + } + scf.forall.in_parallel { + tensor.parallel_insert_slice %5#2 into %arg7[%3, %4] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<128x256xbf16> + tensor.parallel_insert_slice %5#1 into %arg6[%3, %4] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<128x256xbf16> + tensor.parallel_insert_slice %5#0 into %arg5[%3, %4] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<128x256xbf16> 
+ } + } + return %2#2 : tensor<128x256xbf16> + } From ab7c4d0825ea469d779e6c3d8cf61c7ba50ad04d Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Tue, 23 Jul 2024 15:25:41 +0800 Subject: [PATCH 18/66] add 16x16 transpose kernel --- include/gc/Transforms/TilingVector.h | 21 +- lib/gc/Transforms/CPUPhysicalResigterPass.cpp | 259 +++++++++++++++++- 2 files changed, 278 insertions(+), 2 deletions(-) diff --git a/include/gc/Transforms/TilingVector.h b/include/gc/Transforms/TilingVector.h index 38bc51d3c..9ac9f4a5e 100644 --- a/include/gc/Transforms/TilingVector.h +++ b/include/gc/Transforms/TilingVector.h @@ -18,6 +18,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/Transforms/Transforms.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" #include "mlir/Dialect/Vector/Transforms/Passes.h" #include "mlir/ExecutionEngine/Float16bits.h" #include "mlir/IR/AffineMap.h" @@ -246,16 +247,22 @@ class BroadcastCanonicalizer class TransposeCanonicalizer : public SpecialOperationCanonicalizer { private: + size_t firstTpIdx = 0, secondTpIdx = 0; + public: TransposeCanonicalizer( const llvm::SmallVector &candidateTpOps) : SpecialOperationCanonicalizer( candidateTpOps, SpecialOperationKind::OP_Transpose){}; virtual ~TransposeCanonicalizer() {} - void prepareSpecialOperationInfo() override {} + void prepareSpecialOperationInfo() override; static bool classof(SpecialOperationCanonicalizer *canonicalizer) { return canonicalizer->getKind() == SpecialOperationKind::OP_Transpose; } + + size_t getFirstTpIdx() { return firstTpIdx; } + size_t getSecondTpIdx() { return secondTpIdx; } + bool isTwoDTranspose(); }; class ShapeCastCanonicalizer @@ -514,6 +521,18 @@ class ForLoopGenerator : virtual public CanonicalizerCommonUsedData { const llvm::SmallVector ¶llelAxis, llvm::SmallVector &inductionVars, bool lastDimReduction, MultiReduceOpAxisKind rdKind = MultiReduceOpAxisKind::Parallel); + + /// transpose operation related + scf::ForOp generateTransposeForLoop(const size_t groupId); + scf::ForOp generateTransposeForLoopWithLastDim( + OpBuilder &opBuilder, const size_t grpIdx, const size_t forDimIdx, + const int tpSteps, const Location &loc, SmallVector &inductionVars, + const ValueRange &iterArgs); + + scf::ForOp generateScalarDataMovement( + OpBuilder &opBuilder, const size_t grpIdx, const size_t forDimIdx, + const Location &loc, SmallVector &inductionVars, + const ValueRange &iterArgs, DenseMap &tpAxisMap); }; class VectorOperationAnalyzer : virtual public CanonicalizerCommonUsedData { diff --git a/lib/gc/Transforms/CPUPhysicalResigterPass.cpp b/lib/gc/Transforms/CPUPhysicalResigterPass.cpp index 1598fd7f9..b4afbea35 100644 --- a/lib/gc/Transforms/CPUPhysicalResigterPass.cpp +++ b/lib/gc/Transforms/CPUPhysicalResigterPass.cpp @@ -1430,6 +1430,7 @@ scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( forResultOrignalResultMap, originalResultForResultMap); // insert accumulate value to original vector + // TODO: fix first accumualte idx use map auto accRes = nxtFor->getResults()[0]; Operation *reductionOp = b.create( @@ -1469,6 +1470,79 @@ scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( }); } +scf::ForOp ForLoopGenerator::generateTransposeForLoopWithLastDim( + OpBuilder &opBuilder, const size_t grpIdx, const size_t forDimIdx, + const int tpSteps, const Location &loc, SmallVector &inductionVars, + const ValueRange &iterArgs) { + auto &tpCanonicalizer = getTransposeCanonicalizers()[grpIdx]; + vector::TransposeOp &tpOp = 
tpCanonicalizer.getCandidateOps()[0]; + VectorType vtType = tpOp.getVector().getType(); + size_t rank = vtType.getRank(); + + auto zero = makeIndexArithConstantOp(opBuilder, loc, 0); + bool isTransposeDim = forDimIdx == tpCanonicalizer.getFirstTpIdx() or + forDimIdx == tpCanonicalizer.getSecondTpIdx(); + auto forSteps = + makeIndexArithConstantOp(opBuilder, loc, isTransposeDim ? tpSteps : 1); + auto numIter = + makeIndexArithConstantOp(opBuilder, loc, vtType.getShape()[forDimIdx]); + VectorType kernelType = + VectorType::get({tpSteps, tpSteps}, vtType.getElementType()); + // generate transpose for loop + return opBuilder.create( + loc, zero, numIter, forSteps, iterArgs, + [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) { + inductionVars.emplace_back(iv); + + // inner most body of the loop + if (forDimIdx == rank - 1) { + // transfer read from source tensor + Value source = tpOp->getOperand(0); + auto readSourceOp = + cast(source.getDefiningOp()); + vector::TransferWriteOp successorWriteOp; + for (Operation *x : tpOp->getUsers()) { + if (isa(x)) { + successorWriteOp = cast(x); + } + } + auto padValue = b.create( + loc, b.getZeroAttr(vtType.getElementType())); + SmallVector inBoundsVal(2, true); + inBoundsVal[0] = !ShapedType::isDynamic( + vtType.getShape()[tpCanonicalizer.getFirstTpIdx()]); + inBoundsVal[1] = !ShapedType::isDynamic( + vtType.getShape()[tpCanonicalizer.getSecondTpIdx()]); + + auto transferReadOp = b.create( + loc, + /*vectorType=*/kernelType, + /*source=*/readSourceOp.getSource(), + /*indices=*/inductionVars, + /*padding=*/padValue, + /*inBounds=*/inBoundsVal); + SmallVector perm{1, 0}; + auto transposeOp = b.create( + loc, transferReadOp->getResults()[0], perm); + SmallVector writeVars(inductionVars.begin(), + inductionVars.end()); + writeVars[tpCanonicalizer.getSecondTpIdx()] = + inductionVars[tpCanonicalizer.getFirstTpIdx()]; + writeVars[tpCanonicalizer.getFirstTpIdx()] = + inductionVars[tpCanonicalizer.getSecondTpIdx()]; + auto writeOp = b.create( + loc, transposeOp->getResults()[0], + successorWriteOp->getOperands()[1], writeVars, inBoundsVal); + maybeYieldValue(b, loc, writeOp->getResults()); + } else { + // outter loop + auto nxtFor = generateTransposeForLoopWithLastDim( + b, grpIdx, forDimIdx + 1, tpSteps, loc, inductionVars, loopState); + maybeYieldValue(b, loc, nxtFor->getResults()); + } + }); +} + scf::ForOp ForLoopGenerator::generateMultiReductionForLoop(const size_t grpIdx) { auto &rdCanonicalizer = getMultiRdCanonicalizers()[grpIdx]; @@ -1539,6 +1613,144 @@ ForLoopGenerator::generateMultiReductionForLoop(const size_t grpIdx) { } rewriter.replaceOp(getMultiRdCanonicalizers()[grpIdx].getCandidateOps()[0], forOp); + + return forOp; +} + +// generate simple data movement for loop +scf::ForOp ForLoopGenerator::generateScalarDataMovement( + OpBuilder &opBuilder, const size_t grpIdx, const size_t forDimIdx, + const Location &loc, SmallVector &inductionVars, + const ValueRange &iterArgs, DenseMap &tpAxisMap) { + auto &tpCanonicalizer = getTransposeCanonicalizers()[grpIdx]; + vector::TransposeOp &tpOp = tpCanonicalizer.getCandidateOps()[0]; + VectorType vtType = tpOp.getVector().getType(); + size_t rank = vtType.getRank(); + + auto zero = makeIndexArithConstantOp(opBuilder, loc, 0); + auto forSteps = makeIndexArithConstantOp(opBuilder, loc, 1); + auto numIter = + makeIndexArithConstantOp(opBuilder, loc, vtType.getShape()[forDimIdx]); + VectorType kernelType = VectorType::get({1}, vtType.getElementType()); + // generate transpose for loop + return 
opBuilder.create( + loc, zero, numIter, forSteps, iterArgs, + [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) { + inductionVars.emplace_back(iv); + + // inner most body of the loop + if (forDimIdx == rank - 1) { + // transfer read from source tensor + Value source = tpOp->getOperand(0); + auto readSourceOp = + cast(source.getDefiningOp()); + vector::TransferWriteOp successorWriteOp; + for (Operation *x : tpOp->getUsers()) { + if (isa(x)) { + successorWriteOp = cast(x); + } + } + auto padValue = b.create( + loc, b.getZeroAttr(vtType.getElementType())); + SmallVector inBoundsVal(1, true); + + auto transferReadOp = b.create( + loc, + /*vectorType=*/kernelType, + /*source=*/readSourceOp.getSource(), + /*indices=*/inductionVars, + /*padding=*/padValue, + /*inBounds=*/inBoundsVal); + SmallVector writeVars; + size_t itrIdx = 0; + while (itrIdx < rank) { + writeVars.emplace_back(inductionVars[tpAxisMap[itrIdx]]); + itrIdx++; + } + + auto writeOp = b.create( + loc, transferReadOp->getResults()[0], + successorWriteOp->getOperands()[1], writeVars, inBoundsVal); + maybeYieldValue(b, loc, writeOp->getResults()); + } else { + // outter loop + auto nxtFor = + generateScalarDataMovement(b, grpIdx, forDimIdx + 1, loc, + inductionVars, loopState, tpAxisMap); + maybeYieldValue(b, loc, nxtFor->getResults()); + } + }); +} + +/// generate transpose for loop +scf::ForOp ForLoopGenerator::generateTransposeForLoop(const size_t grpIdx) { + + // transpose rank must bigger than 2 + TransposeCanonicalizer &tpCanonicalizer = + getTransposeCanonicalizers()[grpIdx]; + vector::TransposeOp &tpOp = tpCanonicalizer.getCandidateOps()[0]; + VectorType vtType = tpOp.getVector().getType(); + std::cout << " _________ check tp operation source." + << "\n"; + vtType.dump(); + tpOp->getResultTypes()[0].dump(); + size_t rank = vtType.getRank(); + if (rank < 2) { + llvm::llvm_unreachable_internal( + "Wrong transpose operation appear. 
It's rank must bigger than 2."); + return nullptr; + } + + // permutation contains last dim can use optimizing algorithm + ArrayRef permutation = tpOp.getPermutation(); + DenseSet permuteSet(permutation.begin(), permutation.end()); + bool isTwoDTranspose = tpCanonicalizer.isTwoDTranspose(); + const int tpStep = 16; + // currently we only support shape that is an integer multiple of tpStep + if (vtType.getShape()[tpCanonicalizer.getFirstTpIdx()] % tpStep != 0 or + vtType.getShape()[tpCanonicalizer.getSecondTpIdx()] % tpStep != 0) { + isTwoDTranspose = false; + } + OpBuilder b(tpOp); + SmallVector iterArgs; + vector::TransferWriteOp successorWriteOp; + for (Operation *x : tpOp->getUsers()) { + if (isa(x)) { + successorWriteOp = cast(x); + } + } + iterArgs.emplace_back(successorWriteOp->getOperands()[1]); + SmallVector inductionVars; + IRRewriter rewriter(func); + + if (permuteSet.contains(rank - 1) and isTwoDTranspose) { + std::cout << " can use 16x16 : " << std::endl; + scf::ForOp forOp = generateTransposeForLoopWithLastDim( + b, grpIdx, 0, tpStep, tpOp.getLoc(), inductionVars, iterArgs); + + for (Operation *x : tpOp->getUsers()) { + if (isa(x)) { + rewriter.replaceOp(x, forOp); + } + } + return forOp; + } + // findTransposeAxisMap(DenseMap & tpAxisMap); + DenseMap tpAxisMap; + size_t itrIdx = 0; + while (itrIdx < rank) { + tpAxisMap[itrIdx] = permutation[itrIdx]; + itrIdx++; + } + // scalar data movement + scf::ForOp forOp = generateScalarDataMovement( + b, grpIdx, 0, tpOp.getLoc(), inductionVars, iterArgs, tpAxisMap); + for (Operation *x : tpOp->getUsers()) { + if (isa(x)) { + rewriter.replaceOp(x, forOp); + } + } + std::cout << " scalar data movement." << std::endl; forOp->dump(); return forOp; } @@ -1600,6 +1812,44 @@ void MultiReductionCanonicalizer::prepareSpecialOperationInfo() { hasLastDimReduction(); }; +void TransposeCanonicalizer::prepareSpecialOperationInfo() { + if (getCandidateOps().empty()) { + return; + } +} + +bool TransposeCanonicalizer::isTwoDTranspose() { + ArrayRef permutation = getCandidateOps()[0].getPermutation(); + size_t rank = permutation.size(); + int diffCount = 0; + // get the first transpose axis + size_t itrIdx = 0; + while (itrIdx < rank) { + if ((int64_t)itrIdx != permutation[itrIdx]) { + diffCount += 1; + } + itrIdx += 1; + } + itrIdx = 0; + while (itrIdx < rank) { + if (permutation[itrIdx] != (int64_t)itrIdx) { + firstTpIdx = itrIdx; + break; + } + itrIdx++; + } + itrIdx = 0; + // get the second transpose axis + while (itrIdx < rank) { + if (permutation[itrIdx] == (int64_t)firstTpIdx) { + secondTpIdx = itrIdx; + break; + } + itrIdx++; + } + return diffCount == 2; +} + template void addDummyInit(SmallVector &canonicalizer) { canonicalizer.emplace_back(T({})); }; @@ -1670,7 +1920,7 @@ void CanonicalizerVectorOperation::canonicalizeSpecialOperation() { SmallVector &transposeOps = transposeCanonicalizers[groupId].getCandidateOps(); if (!transposeOps.empty()) { - // (void) generateTransposeForLoop(groupId); + (void)generateTransposeForLoop(groupId); } } } @@ -2707,6 +2957,13 @@ struct CPUPhysicalRegisterPass func, CanonicalizerKind::OperationsGroup, hwInfo); canonicalizer.run(); + // transpose kernel + vector::VectorTransformsOptions transposeOptions = + vector::VectorTransformsOptions(); + transposeOptions.vectorTransposeLowering = + vector::VectorTransposeLowering::Shuffle16x16; + vector::populateVectorTransposeLoweringPatterns(patterns, transposeOptions); + (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); } }; From 
72f1a8bf4b5cd6e46775d525f8f2b3fbd6d9b330 Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Thu, 25 Jul 2024 22:34:20 +0800 Subject: [PATCH 19/66] update reduce, add shapecast, add single matmul test --- include/gc/Transforms/TilingVector.h | 36 +- lib/gc/Transforms/CPUPhysicalResigterPass.cpp | 753 ++++++++++++++---- .../gc/transforms/cpu-vetor-distribution.mlir | 247 ++++-- 3 files changed, 836 insertions(+), 200 deletions(-) diff --git a/include/gc/Transforms/TilingVector.h b/include/gc/Transforms/TilingVector.h index 9ac9f4a5e..b65a91259 100644 --- a/include/gc/Transforms/TilingVector.h +++ b/include/gc/Transforms/TilingVector.h @@ -9,6 +9,7 @@ #define GC_PASSES_TILINGVECTOR_H #include "gc/Transforms/Passes.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" @@ -21,6 +22,7 @@ #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" #include "mlir/Dialect/Vector/Transforms/Passes.h" #include "mlir/ExecutionEngine/Float16bits.h" +#include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" @@ -122,10 +124,8 @@ class VectorFusionStrategy : public TypeHelper { llvm::SmallDenseMap &getGroupBiggestRankVectorType() { return groupBigestRankVectorType; }; - llvm::SmallVector, 8> &getOpGroups() { - return opGroups; - } - llvm::DenseMap &getOpGroupIndexMap() { + SmallVector, 8> &getOpGroups() { return opGroups; } + DenseMap &getOpGroupIndexMap() { return opGroupIndexMap; } llvm::SmallVector &getGroupMaxSteps() { return groupMaxSteps; } @@ -187,6 +187,8 @@ class MultiReductionCanonicalizer std::queue prevOps, postOps, accRelatedOps, sourceRelatedOps; bool haslastDimReduction = false; bool isStandaloneOp = false; + /// empty reduction means that all the reduction axis is 1 + bool isEmptyReduction = true; int64_t typeRank = -1; llvm::SetVector originalOpResults; VectorType sourceType, accType; @@ -205,6 +207,8 @@ class MultiReductionCanonicalizer void getReductionAxisAndParallelAxis(); bool hasLastDimReduction(); bool getIsStandAloneOp() { return isStandaloneOp; } + bool getHasLastDimReduction() { return haslastDimReduction; } + bool getIsEmptyReduction() { return isEmptyReduction; } void initReductionAxis(); void initParallelAxis(); llvm::SmallVector &getReductionAxis() { return reductionAxis; }; @@ -278,6 +282,7 @@ class ShapeCastCanonicalizer static bool classof(SpecialOperationCanonicalizer *canonicalizer) { return canonicalizer->getKind() == SpecialOperationKind::OP_ShapeCast; } + bool isReadWriteOnLastDim(); }; enum class ReturnTypeKind { @@ -363,7 +368,7 @@ class CanonicalizerCommonUsedData : public TypeHelper { return groupOpInitArgs; } - llvm::DenseMap &getOpPermuationMap() { + DenseMap &getOpPermuationMap() { return opPermuationMap; } @@ -391,7 +396,8 @@ class CanonicalizerCommonUsedData : public TypeHelper { Operation *sourceOp, llvm::DenseMap> &srcOpCanoniclizedMap, - size_t anchorPos, ReturnTypeKind retKind); + size_t anchorPos, ReturnTypeKind retKind, + DenseMap &visitedOperation); void updateOpOperandResultInGroups(size_t opGid, Operation *op, Value &init, const Value &result = Value()); @@ -416,12 +422,14 @@ class ForLoopGenerator : virtual public CanonicalizerCommonUsedData { virtual ~ForLoopGenerator() {} void setGeneratorFunc(func::FuncOp &func) { this->func = func; } + void clearCurrentOperationGroup(size_t grpIdx); void generateGroupOpVectorizedIR(const int idx); void 
rewriteOperationAsVectorize(OpBuilder &rewriter, size_t groupId, const std::queue *queue = nullptr); void createNewConstantOp(Operation *srcOp, - vector::TransferWriteOp *transferWriteOp); + vector::TransferWriteOp *transferWriteOp, + size_t groupSteps); // elementwise for loop mlir::FailureOr generateVectorizedForLoop(const size_t groupId, IRRewriter &rewriter, @@ -529,10 +537,17 @@ class ForLoopGenerator : virtual public CanonicalizerCommonUsedData { const int tpSteps, const Location &loc, SmallVector &inductionVars, const ValueRange &iterArgs); - scf::ForOp generateScalarDataMovement( + scf::ForOp generateTransposeScalarDataMovement( OpBuilder &opBuilder, const size_t grpIdx, const size_t forDimIdx, const Location &loc, SmallVector &inductionVars, const ValueRange &iterArgs, DenseMap &tpAxisMap); + + // shapecast + scf::ForOp generateShapeCastForLoop(const size_t grpIdx); + scf::ForOp generateShapeCastReadWriteLoop( + OpBuilder &opBuilder, const size_t grpIdx, const size_t forDimIdx, + const size_t steps, const Location &loc, + SmallVector &inductionVars, const ValueRange &iterArgs); }; class VectorOperationAnalyzer : virtual public CanonicalizerCommonUsedData { @@ -545,7 +560,10 @@ class VectorOperationAnalyzer : virtual public CanonicalizerCommonUsedData { VectorOperationAnalyzer(func::FuncOp &func) : func(func) {} void setAnalysisFunc(func::FuncOp &func) { this->func = func; } - void analysisEmptyGroupAndMaxSteps(); + /// remove the useless operation, due to it result is not require by other + // operation + void analysisEmptyGroup(); + void analysisGroupMaxSteps(); void analysisGroupOperaion(); void analysisGroupOperationResults(); void specialOperationAnchorRectify(); diff --git a/lib/gc/Transforms/CPUPhysicalResigterPass.cpp b/lib/gc/Transforms/CPUPhysicalResigterPass.cpp index b4afbea35..cd262a299 100644 --- a/lib/gc/Transforms/CPUPhysicalResigterPass.cpp +++ b/lib/gc/Transforms/CPUPhysicalResigterPass.cpp @@ -20,7 +20,8 @@ namespace { #define ARITH_CAST_OPERATIONS \ arith::ExtFOp, arith::ExtSIOp, arith::ExtUIOp, arith::BitcastOp, \ - arith::FPToSIOp, arith::FPToUIOp, arith::SIToFPOp, arith::UIToFPOp + arith::FPToSIOp, arith::FPToUIOp, arith::SIToFPOp, arith::UIToFPOp, \ + arith::TruncFOp, arith::TruncIOp // TODO: remove it in the future bool disableSpecialOp = false; @@ -56,6 +57,15 @@ void printQueue(const std::queue &opQueue) { "__________________" << std::endl; } +/// Get the index position of the first element that is true +static size_t getFirstTrueIndex(ArrayRef ararys) { + for (size_t i = 0; i < ararys.size(); i++) { + if (!ararys[i]) { + return i; + } + } + return -1; +} bool isSpecialOp(Operation *op) { return isa(op) || isa(op) || @@ -153,29 +163,42 @@ bool isLastDim(const AffineExpr &expr, const size_t rank) { } } +/// get float or integer dense attribute +/// \param [in,out] attr +template +void getConstantDenseAttr(TypedAttr &attr, VectorType type, + DenseElementsAttr denseAttr) { + using APX = std::conditional_t, + APFloat, APInt>; + attr = T::get(type, denseAttr.getSplatValue()); +} + FailureOr createArithSplatConstantOp(IRRewriter &rewriter, const Location &loc, - const ElementsAttr &valueType, - VectorType &newOperandType) { + DenseElementsAttr valueType, + VectorType newOperandType) { if (valueType.isSplat()) { - Value res; - if (mlir::isa(valueType.getElementType())) { - res = rewriter.create( - loc, - FloatAttr::get(newOperandType, valueType.getSplatValue())); + TypedAttr attr; + if (isa(newOperandType.getElementType())) { + getConstantDenseAttr(attr, 
newOperandType, + valueType); } else { - res = rewriter.create( - loc, - IntegerAttr::get(newOperandType, valueType.getSplatValue())); + getConstantDenseAttr(attr, newOperandType, + valueType); } - return res; + return rewriter.create(loc, attr)->getResults()[0]; } return failure(); } -mlir::FailureOr getOperationVectorType(Operation *op) { +/// get operation vector type +/// \param isPrevOp whether the operation is a previous operation, if it is not +/// prev-op, may need to use result vectortype +/// default will return the opeation result type +mlir::FailureOr getOperationVectorType(Operation *op, + bool isPrevOp = true) { if (!op) { return failure(); } @@ -185,8 +208,8 @@ mlir::FailureOr getOperationVectorType(Operation *op) { .Case( [&](vector::TransferWriteOp transferWriteOp) -> mlir::FailureOr { - auto retType = mlir::dyn_cast( - transferWriteOp->getOperand(0).getType()); + auto retType = + dyn_cast(transferWriteOp.getOperandTypes()[0]); if (retType) { return retType; } @@ -200,21 +223,30 @@ mlir::FailureOr getOperationVectorType(Operation *op) { }) .Case( [&](vector::MultiDimReductionOp multiReductionOp) { - return multiReductionOp.getSourceVectorType(); - }) - .Case( - [&](arith::ConstantOp constantOp) -> mlir::FailureOr { - return failure(); + if (isPrevOp) { + return cast( + multiReductionOp->getResultTypes()[0]); + } + // TODO: may need to add accumulate value vectortype + return cast(multiReductionOp.getSourceVectorType()); }) .Default([&](Operation *op) -> mlir::FailureOr { - if (!op->getResults().empty()) { - auto t = dyn_cast(op->getResultTypes().front()); - if (t) { - if (isDynamicType(t)) { - return failure(); - } - return t; + if (isPrevOp) { + if (op->getResultTypes().empty()) { + return failure(); + } + if (auto shapedType = + dyn_cast(op->getResultTypes()[0])) { + return shapedType; } + return failure(); + } + if (op->getOperandTypes().empty()) { + return failure(); + } + if (auto shapedType = + dyn_cast(op->getOperandTypes()[0])) { + return shapedType; } return failure(); }); @@ -224,10 +256,62 @@ mlir::FailureOr getOperationVectorType(Operation *op) { return ret; } +/// get operation vector type +/// \param isPrevOp whether the operation is a previous operation, if it is not +/// prev-op, may need to use result vectortype +/// default will return the opeation result type +mlir::FailureOr getOperationMaxVectorType(Operation *op) { + if (!op) { + return failure(); + } + auto isDynamicType = [](VectorType &type) { return !type.hasStaticShape(); }; + auto ret = + TypeSwitch>(op) + .Case( + [&](vector::TransferWriteOp transferWriteOp) + -> mlir::FailureOr { + auto retType = + dyn_cast(transferWriteOp.getOperandTypes()[0]); + if (retType) { + return retType; + } + LDBG("TransferWrite Operation has wrong vector to write."); + return failure(); + }) + .Case( + [&](vector::TransferReadOp transferReadOp) + -> mlir::FailureOr { + return transferReadOp.getVectorType(); + }) + .Case( + [&](vector::MultiDimReductionOp multiReductionOp) { + return cast(multiReductionOp.getSourceVectorType()); + }) + .Default([&](Operation *op) -> mlir::FailureOr { + if (op->getResultTypes().empty() and + op->getOperandTypes().empty()) { + return failure(); + } + if (op->getResultTypes().empty()) { + return dyn_cast(op->getOperandTypes()[0]); + } + if (op->getOperandTypes().empty()) { + return dyn_cast(op->getResultTypes()[0]); + } + auto opdType = dyn_cast(op->getOperandTypes()[0]); + auto retType = dyn_cast(op->getResultTypes()[0]); + return opdType.getRank() > retType.getRank() ? 
opdType : retType; + }); + if (!failed(ret) and isDynamicType(ret.value())) { + return failure(); + } + return ret; +} + VectorType TypeHelper::getVectorzedType(Operation *op, uint32_t loop_step) { // Check that the operation type can be broken // down into a loop. - auto baseType = getOperationVectorType(op); + mlir::FailureOr baseType = getOperationVectorType(op); if (failed(baseType)) { LDBG("Failed to get vector type for operation: " << *op << "\n"); assert(false && "Failed to get vector type for operation"); @@ -588,12 +672,24 @@ Operation *createTensorEmptyBefore(Operation *op) { rtType.getElementType(), dynDims); } -Value getOperationResultTensor(Operation *op) { - auto result = op->getResults()[0]; - for (auto x : result.getUsers()) { - if (mlir::isa(x)) { - return x->getOperand(1); +/// get the tensor that operation should write into +Value getOperationResultTensor( + Operation *op, DenseMap &visitedOperation) { + OpResult result = op->getResults()[0]; + for (Operation *x : result.getUsers()) { + if (!isa(x)) { + continue; + } + Value sourceTensor = x->getOperands()[1]; + Operation *srcOp = sourceTensor.getDefiningOp(); + if (!visitedOperation.contains(srcOp)) { + continue; } + size_t pos = visitedOperation[srcOp]; + if (pos > visitedOperation[op]) { + continue; + } + return sourceTensor; } LDBG("Result not write back to tensor."); @@ -679,8 +775,9 @@ Operation *CanonicalizerCommonUsedData::createTransferReadOpBefore( // canonicalizing operation as tensor empty and transfer write the operation // result into the empty tensor [[nodiscard]] std::pair -canonicalizeSourceOperation(Operation *op) { - auto resultTensor = getOperationResultTensor(op); +canonicalizeSourceOperation(Operation *op, + DenseMap &visitedOperation) { + auto resultTensor = getOperationResultTensor(op, visitedOperation); auto writeOp = createTransferWriteOpAfter(op, resultTensor); return std::make_pair(resultTensor, writeOp->getResults()[0]); } @@ -1009,6 +1106,7 @@ void ForLoopGenerator::replaceOperationsWithForLoopResult( for (auto [nxtForResult, nextLoopResult] : zip(forResults, nextAnchorResults)) { Value originalResult = forResultOrignalResultMap[nextLoopResult]; + rewrite.replaceOpUsesWithIf(originalResult.getDefiningOp(), nxtForResult, replaceIfFn); } @@ -1200,19 +1298,21 @@ scf::ForOp ForLoopGenerator::reductionAxisGenerateForLoop( nextAnchorResultsIdxMap, nxtFor->getResults(), movedOperation, forResultOrignalResultMap); - // reduction must return acc + // reduction must return accumulate if (originalResultForResultMap.contains( multireductionOp->getResults()[0])) { - Value originalValue = + Value lastForResult = originalResultForResultMap[multireductionOp->getResults()[0]]; size_t retIdx = - nextAnchorArgsIdxMap[forResultOrignalResultMap[originalValue]]; + nextAnchorArgsIdxMap[forResultOrignalResultMap[lastForResult]]; Value forRes = nxtFor->getResults()[retIdx]; nextAnchorResults.emplace_back(forRes); nextAnchorResultsIdxMap[forRes] = nextAnchorResults.size() - 1; - forResultOrignalResultMap[forRes] = originalValue; - originalResultForResultMap[originalValue] = forRes; + forResultOrignalResultMap[forRes] = + multireductionOp->getResults()[0]; + originalResultForResultMap[multireductionOp->getResults()[0]] = + forRes; } maybeYieldValue(b, loc, nextAnchorResults); @@ -1242,29 +1342,41 @@ scf::ForOp ForLoopGenerator::reductionAxisGenerateForLoop( moveOperationsToCurrentForBody(groupIdx, b, inductionVars, currentLoopStateIdxMap, loopState, movingOperation); + if (!rdCanonicalizer.getIsEmptyReduction()) { + 
int accValIdx = currentLoopStateIdxMap + [originalOperandLoopArgsMap[multireductionOp.getAcc()]]; - int accValIdx = currentLoopStateIdxMap - [originalOperandLoopArgsMap[multireductionOp.getAcc()]]; - - Value reductionResult = makeArithReduction( - b, loc, multireductionOp.getKind(), multireductionOp.getSource(), - loopState[accValIdx]); - - movePostOpToCurrentAnchor( - b, anchorIdx, groupIdx, ValueRange(), b.getBlock(), opQueue, - movingOperation, inductionVars, currentLoopStateIdxMap, loopState, - nextAnchorResults, forResultOrignalResultMap); - - nextAnchorResults.clear(); - nextAnchorResults.emplace_back(reductionResult); - nextAnchorResultsIdxMap[reductionResult] = 0; - forResultOrignalResultMap[reductionResult] = - multireductionOp->getResults()[0]; - originalResultForResultMap[multireductionOp->getResults()[0]] = - reductionResult; - getResultInCurrentOps(anchorIdx, groupIdx, movingOperation, - nextAnchorResults, forResultOrignalResultMap); + Value reductionResult = makeArithReduction( + b, loc, multireductionOp.getKind(), + multireductionOp.getSource(), loopState[accValIdx]); + + movePostOpToCurrentAnchor( + b, anchorIdx, groupIdx, ValueRange(), b.getBlock(), opQueue, + movingOperation, inductionVars, currentLoopStateIdxMap, + loopState, nextAnchorResults, forResultOrignalResultMap); + + nextAnchorResults.clear(); + nextAnchorResults.emplace_back(reductionResult); + nextAnchorResultsIdxMap[reductionResult] = 0; + forResultOrignalResultMap[reductionResult] = + multireductionOp->getResults()[0]; + originalResultForResultMap[multireductionOp->getResults()[0]] = + reductionResult; + getResultInCurrentOps(anchorIdx, groupIdx, movingOperation, + nextAnchorResults, forResultOrignalResultMap); + } else { + Value sourceVal = multireductionOp.getSource(); + nextAnchorResults.clear(); + nextAnchorResults.emplace_back(sourceVal); + nextAnchorResultsIdxMap[sourceVal] = 0; + forResultOrignalResultMap[sourceVal] = + multireductionOp->getResults()[0]; + originalResultForResultMap[multireductionOp->getResults()[0]] = + sourceVal; + getResultInCurrentOps(anchorIdx, groupIdx, movingOperation, + nextAnchorResults, forResultOrignalResultMap); + } maybeYieldValue(b, loc, nextAnchorResults); } }); @@ -1292,7 +1404,12 @@ scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( SmallVector ¶llelAxis = rdCanonicalizer.getParallelAxis(); const Location &loc = multiReductionOp.getLoc(); Value zero = makeIndexArithConstantOp(opBuilder, loc, 0); - Value forSteps = makeIndexArithConstantOp(opBuilder, loc, 1); + size_t grpMaxStep = getFusionStrategy().getGroupMaxSteps()[groupIdx]; + size_t actualStep = (parallelIdx == parallelAxis.size() - 1 and + !rdCanonicalizer.getHasLastDimReduction()) + ? grpMaxStep + : 1; + Value forSteps = makeIndexArithConstantOp(opBuilder, loc, actualStep); // last dim reduction need to a generate dim=16 loop for fused with pre-op int dimSize = 0; @@ -1370,12 +1487,25 @@ scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( } } + scf::ForOp nxtFor; + DenseMap originalResultForResultMap; // 2. 
generate next for loop - scf::ForOp nxtFor = parallelAxisGenerateForLoop( - b, groupIdx, parallelIdx + 1, nextAnchorArgsIdxMap, - nextAnchorArgs, nextAnchorResults, nextAnchorResultsIdxMap, - inductionVars, originalOperandLoopArgsMap, - loopArgsOriginalOperandMap, forResultOrignalResultMap); + if (rdCanonicalizer.hasLastDimReduction() or + parallelIdx < parallelAxis.size() - 1) { + nxtFor = parallelAxisGenerateForLoop( + b, groupIdx, parallelIdx + 1, nextAnchorArgsIdxMap, + nextAnchorArgs, nextAnchorResults, nextAnchorResultsIdxMap, + inductionVars, originalOperandLoopArgsMap, + loopArgsOriginalOperandMap, forResultOrignalResultMap); + } else if (parallelAxis.size() - 1 == parallelIdx) { + + nxtFor = reductionAxisGenerateForLoop( + b, groupIdx, 0, parallelIdx + 1, nextAnchorArgsIdxMap, + nextAnchorArgs, originalOperandLoopArgsMap, + loopArgsOriginalOperandMap, nextAnchorResults, + nextAnchorResultsIdxMap, inductionVars, + forResultOrignalResultMap, originalResultForResultMap); + } // 3. move postOp to current body movePostOpToCurrentAnchor( @@ -1545,8 +1675,10 @@ scf::ForOp ForLoopGenerator::generateTransposeForLoopWithLastDim( scf::ForOp ForLoopGenerator::generateMultiReductionForLoop(const size_t grpIdx) { - auto &rdCanonicalizer = getMultiRdCanonicalizers()[grpIdx]; - auto multiReductionOp = rdCanonicalizer.getCandidateOps()[0]; + MultiReductionCanonicalizer &rdCanonicalizer = + getMultiRdCanonicalizers()[grpIdx]; + vector::MultiDimReductionOp multiReductionOp = + rdCanonicalizer.getCandidateOps()[0]; std::queue &prevOps = rdCanonicalizer.getPrevOps(); std::queue &postOps = rdCanonicalizer.getPostOps(); std::queue &accRelatedOps = rdCanonicalizer.getAccRelatedOps(); @@ -1559,7 +1691,7 @@ ForLoopGenerator::generateMultiReductionForLoop(const size_t grpIdx) { getPostOps(postOps, copyOpQueue, multiReductionOp); classifyAccRelatedOps(accRelatedOps, sourceRelatedOps, multiReductionOp.getAcc().getDefiningOp(), prevOps); - // move acc related operation to operation first + // move accumulate related operation to operation first std::queue rectifyQueue; DenseSet pushedSet; auto moveOperation = [&](std::queue &from, @@ -1611,14 +1743,13 @@ ForLoopGenerator::generateMultiReductionForLoop(const size_t grpIdx) { originalResult.getDefiningOp(), forOp->getResults()[nextAnchorResultsIdxMap[x]], replaceIfFn); } - rewriter.replaceOp(getMultiRdCanonicalizers()[grpIdx].getCandidateOps()[0], - forOp); + rewriter.eraseOp(multiReductionOp); return forOp; } // generate simple data movement for loop -scf::ForOp ForLoopGenerator::generateScalarDataMovement( +scf::ForOp ForLoopGenerator::generateTransposeScalarDataMovement( OpBuilder &opBuilder, const size_t grpIdx, const size_t forDimIdx, const Location &loc, SmallVector &inductionVars, const ValueRange &iterArgs, DenseMap &tpAxisMap) { @@ -1674,14 +1805,200 @@ scf::ForOp ForLoopGenerator::generateScalarDataMovement( maybeYieldValue(b, loc, writeOp->getResults()); } else { // outter loop - auto nxtFor = - generateScalarDataMovement(b, grpIdx, forDimIdx + 1, loc, - inductionVars, loopState, tpAxisMap); + auto nxtFor = generateTransposeScalarDataMovement( + b, grpIdx, forDimIdx + 1, loc, inductionVars, loopState, + tpAxisMap); maybeYieldValue(b, loc, nxtFor->getResults()); } }); } +// generate simple data movement for loop +scf::ForOp ForLoopGenerator::generateShapeCastReadWriteLoop( + OpBuilder &opBuilder, const size_t grpIdx, const size_t forDimIdx, + const size_t steps, const Location &loc, SmallVector &inductionVars, + const ValueRange &iterArgs) { + auto 
&scCanonicalizer = getShapeCastCanonicalizers()[grpIdx]; + vector::ShapeCastOp &scOp = scCanonicalizer.getCandidateOps()[0]; + VectorType sourceType = scOp.getSourceVectorType(); + VectorType destType = scOp.getResultVectorType(); + VectorType loopType = + sourceType.getRank() > destType.getRank() ? sourceType : destType; + size_t rank = loopType.getRank(); + + auto zero = makeIndexArithConstantOp(opBuilder, loc, 0); + bool isLastDim = loopType.getRank() - 1 == (int64_t)forDimIdx; + auto forSteps = + makeIndexArithConstantOp(opBuilder, loc, isLastDim ? steps : 1); + auto numIter = + makeIndexArithConstantOp(opBuilder, loc, loopType.getShape()[forDimIdx]); + VectorType kernelType = + VectorType::get({(int64_t)steps}, loopType.getElementType()); + + // generate transpose for loop + return opBuilder.create( + loc, zero, numIter, forSteps, iterArgs, + [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) { + inductionVars.emplace_back(iv); + + // inner most body of the loop + if (forDimIdx == rank - 1) { + sourceType.dump(); + destType.dump(); + // transfer read from source tensor + Value source = scOp->getOperand(0); + auto readSourceOp = + cast(source.getDefiningOp()); + vector::TransferWriteOp successorWriteOp; + for (Operation *x : scOp->getUsers()) { + if (isa(x)) { + successorWriteOp = cast(x); + break; + } + } + SmallVector exprs(loopType.getRank(), AffineExpr()); + bindSymbolsList(b.getContext(), exprs); + SmallVector operands{inductionVars.begin(), + inductionVars.end()}; + SmallVector smallRankShapeVars; + + auto getSmallRankShapeVars = [&](VectorType smallType) { + size_t itrIdx = 0; + SmallVector visitedAxis(rank, false); + while ((int64_t)itrIdx < smallType.getRank()) { + + size_t endShape = getFirstTrueIndex(visitedAxis), dimSize = 1; + assert(endShape < rank and endShape >= 0 && "Invalid endShape"); + // skip non corresponding axis + // e.g.: vector<32x16x1x32xbf16> -> vector<1x512x32xbf16> + while (loopType.getShape()[endShape] > + smallType.getShape()[itrIdx]) { + endShape++; + } + const size_t expandIdx = endShape; + while (endShape < rank) { + visitedAxis[endShape] = true; + dimSize *= loopType.getShape()[endShape]; + if ((int64_t)dimSize == smallType.getShape()[itrIdx]) { + break; + } + endShape += 1; + } + const size_t expandSize = endShape - expandIdx + 1; + AffineExpr calculateOffset; + SmallVector offsetVars; + + for (size_t i = 0; i < expandSize; i++) { + size_t startIdx = i + 1; + size_t otherDimsSize = 1; + while (startIdx < expandSize) { + otherDimsSize *= (loopType.getShape()[startIdx + expandIdx]); + startIdx++; + } + AffineExpr dimSize = + getAffineConstantExpr(otherDimsSize, b.getContext()); + if (i == 0) { + calculateOffset = exprs[i] * dimSize; + } else { + calculateOffset = calculateOffset + exprs[i] * dimSize; + } + + offsetVars.emplace_back(inductionVars[i + expandIdx]); + } + AffineMap map = AffineMap::get(0, expandSize, calculateOffset); + + Value offset = + b.createOrFold(loc, map, offsetVars); + smallRankShapeVars.emplace_back(offset); + itrIdx++; + } + }; + + if (loopType == sourceType) { + getSmallRankShapeVars(destType); + } else { + getSmallRankShapeVars(sourceType); + } + + auto padValue = b.create( + loc, b.getZeroAttr(loopType.getElementType())); + + SmallVector inBoundsVal(1, true); + + auto transferReadOp = b.create( + loc, + /*vectorType=*/kernelType, + /*source=*/readSourceOp->getOperands()[0], + /*indices=*/loopType == sourceType ? 
inductionVars + : smallRankShapeVars, + /*padding=*/padValue, + /*inBounds=*/inBoundsVal); + + auto writeOp = b.create( + loc, transferReadOp->getResults()[0], + successorWriteOp->getOperands()[1], + loopType == sourceType ? smallRankShapeVars : inductionVars, + inBoundsVal); + maybeYieldValue(b, loc, writeOp->getResults()); + } else { + // outter loop + auto nxtFor = generateShapeCastReadWriteLoop( + b, grpIdx, forDimIdx + 1, steps, loc, inductionVars, loopState); + maybeYieldValue(b, loc, nxtFor->getResults()); + } + }); +} + +/// generate transpose for loop +scf::ForOp ForLoopGenerator::generateShapeCastForLoop(const size_t grpIdx) { + + ShapeCastCanonicalizer &scCanonicalizer = + getShapeCastCanonicalizers()[grpIdx]; + vector::ShapeCastOp &scOp = scCanonicalizer.getCandidateOps()[0]; + + VectorType sourceType = scOp.getSourceVectorType(); + VectorType destType = scOp.getResultVectorType(); + + OpBuilder b(scOp); + SmallVector iterArgs; + vector::TransferWriteOp successorWriteOp; + for (Operation *x : scOp->getUsers()) { + if (isa(x)) { + successorWriteOp = cast(x); + break; + } + } + iterArgs.emplace_back(successorWriteOp->getOperands()[1]); + SmallVector inductionVars; + IRRewriter rewriter(func); + const size_t groupStep = getFusionStrategy().getGroupMaxSteps()[grpIdx]; + + bool isSourceMultiple = + sourceType.getShape()[sourceType.getRank() - 1] % groupStep == 0; + bool isDestMultiple = + destType.getShape()[destType.getRank() - 1] % groupStep == 0; + + if (isDestMultiple and isSourceMultiple and + scCanonicalizer.isReadWriteOnLastDim()) { + scf::ForOp forOp = generateShapeCastReadWriteLoop( + b, grpIdx, 0, groupStep, scOp.getLoc(), inductionVars, iterArgs); + rewriter.replaceOp(successorWriteOp, forOp); + clearCurrentOperationGroup(grpIdx); + return forOp; + } + + // scalar data movement + scf::ForOp forOp = generateShapeCastReadWriteLoop( + b, grpIdx, 0, 1, scOp.getLoc(), inductionVars, iterArgs); + rewriter.replaceOp(successorWriteOp, forOp); + clearCurrentOperationGroup(grpIdx); + return forOp; +} + +void ForLoopGenerator::clearCurrentOperationGroup(size_t grpIdx) { + std::queue().swap(getFusionStrategy().getOpGroups()[grpIdx]); +}; + /// generate transpose for loop scf::ForOp ForLoopGenerator::generateTransposeForLoop(const size_t grpIdx) { @@ -1689,11 +2006,8 @@ scf::ForOp ForLoopGenerator::generateTransposeForLoop(const size_t grpIdx) { TransposeCanonicalizer &tpCanonicalizer = getTransposeCanonicalizers()[grpIdx]; vector::TransposeOp &tpOp = tpCanonicalizer.getCandidateOps()[0]; + VectorType vtType = tpOp.getVector().getType(); - std::cout << " _________ check tp operation source." 
- << "\n"; - vtType.dump(); - tpOp->getResultTypes()[0].dump(); size_t rank = vtType.getRank(); if (rank < 2) { llvm::llvm_unreachable_internal( @@ -1717,6 +2031,7 @@ scf::ForOp ForLoopGenerator::generateTransposeForLoop(const size_t grpIdx) { for (Operation *x : tpOp->getUsers()) { if (isa(x)) { successorWriteOp = cast(x); + break; } } iterArgs.emplace_back(successorWriteOp->getOperands()[1]); @@ -1724,18 +2039,14 @@ scf::ForOp ForLoopGenerator::generateTransposeForLoop(const size_t grpIdx) { IRRewriter rewriter(func); if (permuteSet.contains(rank - 1) and isTwoDTranspose) { - std::cout << " can use 16x16 : " << std::endl; scf::ForOp forOp = generateTransposeForLoopWithLastDim( b, grpIdx, 0, tpStep, tpOp.getLoc(), inductionVars, iterArgs); - for (Operation *x : tpOp->getUsers()) { - if (isa(x)) { - rewriter.replaceOp(x, forOp); - } - } + rewriter.replaceOp(successorWriteOp, forOp); + // clear current group operation + clearCurrentOperationGroup(grpIdx); return forOp; } - // findTransposeAxisMap(DenseMap & tpAxisMap); DenseMap tpAxisMap; size_t itrIdx = 0; while (itrIdx < rank) { @@ -1743,15 +2054,11 @@ scf::ForOp ForLoopGenerator::generateTransposeForLoop(const size_t grpIdx) { itrIdx++; } // scalar data movement - scf::ForOp forOp = generateScalarDataMovement( + scf::ForOp forOp = generateTransposeScalarDataMovement( b, grpIdx, 0, tpOp.getLoc(), inductionVars, iterArgs, tpAxisMap); - for (Operation *x : tpOp->getUsers()) { - if (isa(x)) { - rewriter.replaceOp(x, forOp); - } - } - std::cout << " scalar data movement." << std::endl; - forOp->dump(); + + rewriter.replaceOp(successorWriteOp, forOp); + clearCurrentOperationGroup(grpIdx); return forOp; } @@ -1766,6 +2073,7 @@ void MultiReductionCanonicalizer::initReductionAxis() { auto reductionRange = llvm::to_vector<4>(map_range( reductionAxisRange, [](const APInt &a) { return a.getZExtValue(); })); reductionAxis.assign(reductionRange.begin(), reductionRange.end()); + llvm::sort(reductionAxis); } void MultiReductionCanonicalizer::initParallelAxis() { @@ -1776,7 +2084,7 @@ void MultiReductionCanonicalizer::initParallelAxis() { parallelAxis.push_back(i); } } - llvm::sort(parallelAxis.begin(), parallelAxis.end()); + llvm::sort(parallelAxis); } int64_t MultiReductionCanonicalizer::getTypeRank() { @@ -1810,6 +2118,14 @@ void MultiReductionCanonicalizer::prepareSpecialOperationInfo() { getTypeRank(); getReductionAxisAndParallelAxis(); hasLastDimReduction(); + + // whether all the reduction axis is 1 + for (auto axis : reductionAxis) { + if (sourceType.getShape()[axis] != 1) { + isEmptyReduction = false; + break; + } + } }; void TransposeCanonicalizer::prepareSpecialOperationInfo() { @@ -1850,6 +2166,49 @@ bool TransposeCanonicalizer::isTwoDTranspose() { return diffCount == 2; } +bool ShapeCastCanonicalizer::isReadWriteOnLastDim() { + vector::ShapeCastOp &shapeCastOp = getCandidateOps()[0]; + VectorType sourceType = shapeCastOp.getSourceVectorType(); + VectorType destType = shapeCastOp.getResultVectorType(); + VectorType smallRankType = + sourceType.getRank() > destType.getRank() ? destType : sourceType; + VectorType largeRankType = + sourceType.getRank() < destType.getRank() ? destType : sourceType; + SmallVector visitedAxis(largeRankType.getRank(), false); + // Map the index of the larger rank shape to the index of the smaller rank + // shape. 
+ DenseMap> shapeIdxMap; + for (size_t i = 0; i < smallRankType.getRank(); i++) { + shapeIdxMap[i] = std::move(SmallVector()); + } + size_t itrIdx = 0; + while (itrIdx < smallRankType.getRank()) { + size_t endShape = getFirstTrueIndex(visitedAxis), dimSize = 1; + assert(endShape < largeRankType.getRank() and endShape >= 0 && + "Invalid endShape"); + // skip non corresponding axis + // e.g.: vector<32x16x1x32xbf16> -> vector<1x512x32xbf16> + while (largeRankType.getShape()[endShape] > + smallRankType.getShape()[itrIdx]) { + endShape++; + } + while (endShape < largeRankType.getRank()) { + visitedAxis[endShape] = true; + shapeIdxMap[itrIdx].emplace_back(endShape); + dimSize *= largeRankType.getShape()[endShape]; + if ((int64_t)dimSize == smallRankType.getShape()[itrIdx]) { + break; + } + endShape++; + } + itrIdx++; + } + // check if the last dim is read write + SmallVector lastDims = shapeIdxMap[smallRankType.getRank() - 1]; + DenseSet set(lastDims.begin(), lastDims.end()); + return set.contains(largeRankType.getRank() - 1); +} + template void addDummyInit(SmallVector &canonicalizer) { canonicalizer.emplace_back(T({})); }; @@ -1909,6 +2268,8 @@ void CanonicalizerVectorOperation::canonicalizeSpecialOperation() { getMultiRdCanonicalizers(); llvm::SmallVector &transposeCanonicalizers = getTransposeCanonicalizers(); + llvm::SmallVector &shapeCastCanonicalizers = + getShapeCastCanonicalizers(); for (auto [groupId, rdCanonicalizer] : llvm::enumerate(multiRdCanonicalizers)) { SmallVector &rdOps = @@ -1922,6 +2283,12 @@ void CanonicalizerVectorOperation::canonicalizeSpecialOperation() { if (!transposeOps.empty()) { (void)generateTransposeForLoop(groupId); } + + SmallVector &shapeCastOps = + shapeCastCanonicalizers[groupId].getCandidateOps(); + if (!shapeCastOps.empty()) { + (void)generateShapeCastForLoop(groupId); + } } } @@ -1965,8 +2332,14 @@ void CanonicalizerVectorOperation::run() { // Query groupResultYeildSet to map operaion result value to scf.yield // result value. + printGroupOps(getFusionStrategy().getOpGroups()); + std::cout << "___________ before analysis ________________" + << "\n"; analysisGroupOperaion(); - // printGroupOps(fusionStrategy.getOpGroups()); + std::cout << "___________ after analysis ________________" + << "\n"; + printGroupOps(getFusionStrategy().getOpGroups()); + // Speical Operation Canonicalization canonicalizeSpecialOperation(); @@ -1998,6 +2371,12 @@ void CanonicalizerVectorOperation::run() { return false; } + // We don't need to vectorize the constant operation + if (isa(op)) { + LDBG("Operation is constantOp" << *op << "\n"); + return false; + } + if (mlir::isa(op) || mlir::isa(op)) { if (!isReadWriteOnLastDim(op)) { @@ -2100,15 +2479,26 @@ bool isSameVectorType(Operation *op1, Operation *op2) { return isSame; } +/// default op1 is previous operation bool VectorFusionStrategy::isCompatibleVectorType(Operation *op1, Operation *op2) { - auto type1 = getOperationVectorType(op1); - auto type2 = getOperationVectorType(op2); + // only lower to vector pass can produce read operation. 
In general two read + // operation is compatible + if (isa(op1) and isa(op2)) { + return true; + } + + mlir::FailureOr type1 = getOperationVectorType(op1, true); + mlir::FailureOr type2 = getOperationVectorType(op2, false); + // some operation has two different operands type like multireduction, we need + // to check whether compitable with accumulate vector + VectorType suppleType; if (failed(type1) || failed(type2)) { return false; } auto sp1 = type1.value(); auto sp2 = type2.value(); + // if (isReadOrWriteOperation(op1) or isReadOrWriteOperation(op2)) { // if (sp1.getRank() != sp2.getRank()) { // return false; @@ -2119,17 +2509,31 @@ bool VectorFusionStrategy::isCompatibleVectorType(Operation *op1, // } // } // } - bool isCompatible = true; - auto min_rank = std::min(sp1.getRank(), sp2.getRank()); - // from front to back - for (long i = 0; i < min_rank; i++) { - if (sp1.getDimSize(i) != sp2.getDimSize(i)) { - isCompatible = false; - break; + + auto isCompatible = [](VectorType sp1, VectorType sp2) { + bool isCompatible = true; + auto min_rank = std::min(sp1.getRank(), sp2.getRank()); + // from front to back + for (long i = 0; i < min_rank; i++) { + if (sp1.getDimSize(i) != sp2.getDimSize(i)) { + isCompatible = false; + break; + } } + return isCompatible; + }; + + bool result; + result = isCompatible(sp1, sp2); + // operand check only happen on later operation is op2 + // TODO: may need to support other similar operation like multireduction has + // two different operands type + if (isa(op2)) { + suppleType = cast(op2->getOperandTypes()[1]); + result |= isCompatible(suppleType, sp1); } - return isCompatible; + return result; } /// which axis do the shape cast in source shape a @@ -2259,11 +2663,13 @@ void getOperationDataAxis(Operation *op, SmallVector &dataAxis) { }); } +/// whether two operation has data dependency +/// op1 default is previous operation, op2 default is current operation bool hasDataDependency(Operation *op1, Operation *op2) { if (!isSpecialOp(op1) and !isSpecialOp(op2)) { return false; } - // op1 must be special operation + // only special operation may cause data dependency if (!isSpecialOp(op1)) { return hasDataDependency(op2, op1); } @@ -2272,6 +2678,23 @@ bool hasDataDependency(Operation *op1, Operation *op2) { if (disableSpecialOp) { return true; } + + // if op1 is read the value and pass it to op2, it is not data dependency + if (isa(op2)) { + for (Value opd : op1->getOperands()) { + if (opd.getDefiningOp() == op2) { + return false; + } + } + } + // if op2 is write the result from op2, it is not data dependency + if (isa(op2)) { + Value opd = op2->getOperand(0); + if (opd.getDefiningOp() == op1) { + return false; + } + } + auto hasSameAxis = [](const SmallVector &dims1, const SmallVector &dims2) { DenseSet checkSet(dims2.begin(), dims2.end()); @@ -2353,7 +2776,6 @@ bool hasDataDependency(Operation *op1, Operation *op2) { SmallVector dims1, dims2; getOperationDataAxis(op1, dims1); getOperationDataAxis(op2, dims2); - return true; if (!isSpecialOp(op2)) { return hasSameAxis(dims1, dims2); } else { @@ -2404,15 +2826,14 @@ void VectorFusionStrategy::updateGroupBitgestVectorType(VectorType vectorType) { void VectorFusionStrategy::addOperationToGroup(Operation *op) { assert(op); - VectorType vectorType = getOperationVectorType(op).value(); + VectorType vectorType = getOperationMaxVectorType(op).value(); if (isNeedNewGroup(op)) { opGroups.emplace_back(std::queue()); - } else { - updateGroupBitgestVectorType(vectorType); } + updateGroupBitgestVectorType(vectorType); 
opGroups.back().push(op); opGroupIndexMap[op] = opGroups.size() - 1; - opAnchorPos[op] = getOperationVectorType(op)->getRank() - 1; + opAnchorPos[op] = getOperationMaxVectorType(op)->getRank() - 1; } // We classify the operations we are interested in after filtering. Operations @@ -2495,15 +2916,19 @@ void setOperationOperandResult(Operation *op, const VectorType &newOperandType, } }; +/// Reimplementation of writing a tensor from a constant of denseElementattr. void ForLoopGenerator::createNewConstantOp( - Operation *srcOp, vector::TransferWriteOp *transferWriteOp) { - auto &opPermuationMap = getOpPermuationMap(); + Operation *srcOp, vector::TransferWriteOp *transferWriteOp, + size_t groupSteps) { + DenseMap &opPermuationMap = getOpPermuationMap(); + IRRewriter srcWriter(srcOp); - auto newOperandType = getVectorzedType(mlir::cast(srcOp)); + VectorType newOperandType = + getVectorzedType(cast(srcOp), groupSteps); auto srcConstantOp = dyn_cast(srcOp); Operation *newConstantOp; - if (mlir::isa(srcConstantOp.getValue())) { - auto valueType = mlir::dyn_cast(srcConstantOp.getValue()); + if (isa(srcConstantOp.getValue())) { + auto valueType = dyn_cast(srcConstantOp.getValue()); if (valueType.isSplat()) { FailureOr res = createArithSplatConstantOp( srcWriter, srcOp->getLoc(), valueType, newOperandType); @@ -2512,20 +2937,26 @@ void ForLoopGenerator::createNewConstantOp( } newConstantOp = res.value().getDefiningOp(); } else { - newConstantOp = srcWriter.create( - srcOp->getLoc(), srcConstantOp.getValue()); + // TODO: need to test not splat value + // newConstantOp = srcWriter.create( + // srcOp->getLoc(), srcConstantOp.getValue()); + llvm::llvm_unreachable_internal( + "Can't support not splat constant value."); } newConstantOp->getResult(0).setType(newOperandType); transferWriteOp->setOperand(0, newConstantOp->getResult(0)); opPermuationMap.insert( - {mlir::cast(srcOp), transferWriteOp->getPermutationMap()}); + {*transferWriteOp, transferWriteOp->getPermutationMap()}); setOpVectorizationPermutationMap( - mlir::cast(srcOp), srcWriter, + *transferWriteOp, srcWriter, mlir::dyn_cast( transferWriteOp->getResults()[0].getType()), transferWriteOp->getPermutationMap()); + return; } + llvm::llvm_unreachable_internal( + "Can't support not DenseElementsAttr constant."); } /// Rewrite the operations in the group to vectorized form. 
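For reference, the splat-constant handling added in createNewConstantOp above boils down to re-splatting the scalar element of the original DenseElementsAttr at the vectorized type and materializing it as a fresh arith.constant that feeds the transfer_write. The following is a minimal standalone sketch of that rewrite, not part of the diff hunks; the helper name rebuildSplatConstant is illustrative only, and it assumes the same headers this file already includes (Arith dialect, builtin attributes/types, LogicalResult).

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LogicalResult.h"

using namespace mlir;

/// Sketch of the splat-constant rewrite used by createNewConstantOp: rebuild
/// the splat attribute at the vectorized type and emit a new arith.constant.
static FailureOr<Value> rebuildSplatConstant(OpBuilder &builder, Location loc,
                                             DenseElementsAttr oldValue,
                                             VectorType newType) {
  // Non-splat constants are rejected, mirroring the pass above.
  if (!oldValue.isSplat())
    return failure();
  TypedAttr attr;
  if (isa<FloatType>(newType.getElementType()))
    attr = DenseFPElementsAttr::get(newType, oldValue.getSplatValue<APFloat>());
  else
    attr = DenseIntElementsAttr::get(newType, oldValue.getSplatValue<APInt>());
  // The new arith.constant already carries the vectorized result type, so the
  // consuming transfer_write can take it directly as its source vector.
  return builder.create<arith::ConstantOp>(loc, attr)->getResult(0);
}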
@@ -2553,7 +2984,7 @@ void ForLoopGenerator::rewriteOperationAsVectorize( Operation *srcOp = transferWriteOp->getOperand(0).getDefiningOp(); if (mlir::isa(srcOp)) { - createNewConstantOp(srcOp, &transferWriteOp); + createNewConstantOp(srcOp, &transferWriteOp, groupSteps); } else { opPermuationMap.insert( {transferWriteOp, transferWriteOp.getPermutationMap()}); @@ -2661,7 +3092,8 @@ void VectorFusionStrategy::run() { classifyOperations(); } void CanonicalizerCommonUsedData::generateEmptyTensorAndWrite( Operation *sourceOp, DenseMap> &srcOpCanoniclizedMap, - size_t anchorPos, ReturnTypeKind retKind) { + size_t anchorPos, ReturnTypeKind retKind, + DenseMap &visitedOperation) { DenseMap &opGroupIndexMap = getFusionStrategy().getOpGroupIndexMap(); SmallVector, 8> &groupOpInitArgs = getGroupOpInitArgs(); @@ -2669,7 +3101,8 @@ void CanonicalizerCommonUsedData::generateEmptyTensorAndWrite( &groupOpResults = getGroupOpResults(); size_t sourceOpGid = opGroupIndexMap[sourceOp]; - auto [tsr, writeOpresult] = canonicalizeSourceOperation(sourceOp); + auto [tsr, writeOpresult] = + canonicalizeSourceOperation(sourceOp, visitedOperation); auto writeOp = writeOpresult.getDefiningOp(); srcOpCanoniclizedMap.insert({sourceOp, {tsr, writeOpresult}}); updateOpOperandResultInGroups(sourceOpGid, sourceOp, tsr, writeOpresult); @@ -2681,35 +3114,47 @@ void CanonicalizerCommonUsedData::generateEmptyTensorAndWrite( getOpPermuationMap()[writeOp] = writeOp.getPermutationMap(); } -void VectorOperationAnalyzer::analysisEmptyGroupAndMaxSteps() { - auto &groupOpResults = getGroupOpResults(); - auto &opGroups = getFusionStrategy().getOpGroups(); - - // If the group operations do not have result need to be returned, these are - // useless code. +void VectorOperationAnalyzer::analysisEmptyGroup() { + SmallVector, 8> &opGroups = + getFusionStrategy().getOpGroups(); + SmallVector>, 8> + &groupOpResults = getGroupOpResults(); for (auto [idx, grp] : llvm::enumerate(opGroups)) { + if (grp.empty()) { + continue; + } if (groupOpResults[idx].empty()) { std::queue().swap(grp); } + } +} + +/// get each operation in each group maximum support vectorization length +void VectorOperationAnalyzer::analysisGroupMaxSteps() { + auto &opGroups = getFusionStrategy().getOpGroups(); + + for (auto [idx, grp] : llvm::enumerate(opGroups)) { + uint32_t steps = std::numeric_limits::max(); - auto &grpSteps = getFusionStrategy().getGroupMaxSteps(); - while (idx >= grpSteps.size()) { + llvm::SmallVector &grpSteps = + getFusionStrategy().getGroupMaxSteps(); + while (idx + 1 > grpSteps.size()) { grpSteps.emplace_back(steps); } std::queue tmpQueue(grp); auto calculateOpSteps = [&](Type type) { - auto opType = mlir::dyn_cast(type); + auto opType = dyn_cast(type); if (opType) steps = std::min(steps, (uint32_t)getDataTypeMAXSIMDLength(opType)); }; while (!tmpQueue.empty()) { auto op = tmpQueue.front(); tmpQueue.pop(); - if (mlir::isa(op)) { + if (isa(op)) { calculateOpSteps(op->getOperandTypes()[0]); } - calculateOpSteps(op->getResultTypes()[0]); + calculateOpSteps(getOperationVectorType(op).value()); } grpSteps[idx] = steps; } @@ -2731,11 +3176,16 @@ void VectorOperationAnalyzer::specialOperationAnchorRectify() { } } -// analysis operation result of current group whether needed by other -// operation which out of current group +/// analysis operation result of current group whether needed by other +/// operation which out of current group void VectorOperationAnalyzer::analysisGroupOperationResults() { DenseMap> srcOpCanoniclizedMap; + // record the operation 
which has been moved DenseSet movedOperationSet; + // record the operation's position which has visited, inorder to ensure set + // correct operand + size_t opCounter = 0; + DenseMap visitedOperation; DenseMap &opGroupIndexMap = getFusionStrategy().getOpGroupIndexMap(); SmallVector, 8> &groupOpInitArgs = getGroupOpInitArgs(); @@ -2764,7 +3214,11 @@ void VectorOperationAnalyzer::analysisGroupOperationResults() { } }; + analysisGroupMaxSteps(); + func.walk([&](Operation *op) { + visitedOperation.insert({op, opCounter++}); + for (auto [idx, opd] : llvm::enumerate(op->getOperands())) { Operation *sourceOp = opd.getDefiningOp(); if (opGroupIndexMap.contains(sourceOp)) { @@ -2795,7 +3249,8 @@ void VectorOperationAnalyzer::analysisGroupOperationResults() { } if (!srcOpCanoniclizedMap.contains(sourceOp)) { generateEmptyTensorAndWrite(sourceOp, srcOpCanoniclizedMap, - OpAnchorPos[sourceOp], rtKind); + OpAnchorPos[sourceOp], rtKind, + visitedOperation); } else { // udpate result return type updateReturnResultKind(sourceOp, sourceOpGid, rtKind); @@ -2821,11 +3276,22 @@ void VectorOperationAnalyzer::analysisGroupOperationResults() { } } } else if (isa_and_nonnull(sourceOp)) { + if (!opGroupIndexMap.contains(op)) { + continue; + } + // TODO: add more operation to this case, write to constant value need + // to do this + if (isa(op)) { + if (idx == 0) + continue; + } + auto constantOp = cast(sourceOp); IRRewriter rewriter(constantOp); - if (mlir::isa(constantOp.getValue())) { + if (isa(constantOp.getValue())) { if (!srcOpCanoniclizedMap.contains(sourceOp)) { - auto [tsr, writeOpresult] = canonicalizeSourceOperation(sourceOp); + auto [tsr, writeOpresult] = + canonicalizeSourceOperation(sourceOp, visitedOperation); srcOpCanoniclizedMap.insert({sourceOp, {tsr, writeOpresult}}); } auto opInit = canonicalizeCurrentOperation( @@ -2840,7 +3306,7 @@ void VectorOperationAnalyzer::analysisGroupOperationResults() { movedOperationSet.insert(op); } }); - analysisEmptyGroupAndMaxSteps(); + analysisEmptyGroup(); specialOperationAnchorRectify(); #undef RESULT_RETURN_TYPE LDBG("Complete analysis group operation results\n"); @@ -2864,7 +3330,7 @@ mlir::FailureOr ForLoopGenerator::generateVectorizedForLoop( operandIdxMap[x] = operands.size() - 1; } ValueRange forIterArgs(operands); - auto shapes = vectorType.getShape(); + ArrayRef shapes = vectorType.getShape(); SmallVector inductionVars; // generate for loop auto forOp = constructNestedForOp( @@ -2944,6 +3410,7 @@ struct CPUPhysicalRegisterPass auto *ctx = &getContext(); RewritePatternSet patterns(ctx); auto func = getOperation(); + if (hasNotSupportOperation(&func)) { LDBG("Not support operation appears in current function."); return; diff --git a/test/gc/transforms/cpu-vetor-distribution.mlir b/test/gc/transforms/cpu-vetor-distribution.mlir index 60d572d2d..b09ff01b3 100644 --- a/test/gc/transforms/cpu-vetor-distribution.mlir +++ b/test/gc/transforms/cpu-vetor-distribution.mlir @@ -143,10 +143,9 @@ func.func @matmul_add(%arg0: tensor<8192x12288xf16>, %arg1: tensor<12288x16384xf #map = affine_map<(d0) -> (d0 * 64)> #map1 = affine_map<(d0) -> (d0 * 128)> -#map2 = affine_map<(d0) -> (d0 floordiv 16)> -#map3 = affine_map<(d0) -> (d0 floordiv 32)> -#map4 = affine_map<(d0, d1, d2) -> (d0 + d1 + d2 * 128)> -#map5 = affine_map<(d0, d1, d2) -> (d0 + d1 + d2 * 64)> +#map2 = affine_map<(d0) -> (d0 * 4)> +#map3 = affine_map<(d0) -> (d0 floordiv 16)> +#map4 = affine_map<(d0) -> (d0 floordiv 32)> func.func @mlp(%arg0: tensor<128x512xbf16>, %arg1: tensor<32x8x16x32xbf16>, %arg2: 
tensor<256xbf16>) -> tensor<128x256xbf16> { %c32 = arith.constant 32 : index %c512 = arith.constant 512 : index @@ -160,58 +159,210 @@ func.func @matmul_add(%arg0: tensor<8192x12288xf16>, %arg1: tensor<12288x16384xf %3 = affine.apply #map(%arg3) %4 = affine.apply #map1(%arg4) %extracted_slice = tensor.extract_slice %arg0[%3, 0] [64, 512] [1, 1] : tensor<128x512xbf16> to tensor<64x512xbf16> - %extracted_slice_0 = tensor.extract_slice %arg5[%3, %4] [64, 128] [1, 1] : tensor<128x256xbf16> to tensor<64x128xbf16> - %5:3 = scf.for %arg8 = %c0 to %c64 step %c64 iter_args(%arg9 = %extracted_slice_0, %arg10 = %extracted_slice_0, %arg11 = %extracted_slice_0) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { - %6:3 = scf.for %arg12 = %c0 to %c128 step %c128 iter_args(%arg13 = %arg9, %arg14 = %arg10, %arg15 = %arg11) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { - %7:3 = scf.for %arg16 = %c0 to %c512 step %c512 iter_args(%arg17 = %arg13, %arg18 = %arg14, %arg19 = %arg15) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { - %extracted_slice_1 = tensor.extract_slice %extracted_slice[%arg8, %arg16] [64, 512] [1, 1] : tensor<64x512xbf16> to tensor<64x512xbf16> - %extracted_slice_2 = tensor.extract_slice %arg17[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> to tensor<64x128xbf16> - %8:3 = scf.for %arg20 = %c0 to %c64 step %c32 iter_args(%arg21 = %extracted_slice_2, %arg22 = %arg18, %arg23 = %arg19) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { - %9:3 = scf.for %arg24 = %c0 to %c128 step %c32 iter_args(%arg25 = %arg21, %arg26 = %arg22, %arg27 = %arg23) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { - %10:3 = scf.for %arg28 = %c0 to %c512 step %c512 iter_args(%arg29 = %arg25, %arg30 = %arg26, %arg31 = %arg27) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { - %extracted_slice_5 = tensor.extract_slice %extracted_slice_1[%arg20, %arg28] [32, 512] [1, 1] : tensor<64x512xbf16> to tensor<32x512xbf16> - %11 = affine.apply #map2(%arg28) - %12 = affine.apply #map3(%arg24) - %extracted_slice_6 = tensor.extract_slice %arg1[%11, %12, 0, 0] [32, 1, 16, 32] [1, 1, 1, 1] : tensor<32x8x16x32xbf16> to tensor<32x1x16x32xbf16> - %extracted_slice_7 = tensor.extract_slice %1[%arg28, %arg24] [512, 32] [1, 1] : tensor<512x256xbf16> to tensor<512x32xbf16> - %unpack = tensor.unpack %extracted_slice_6 inner_dims_pos = [0, 1] inner_tiles = [16, 32] into %extracted_slice_7 : tensor<32x1x16x32xbf16> -> tensor<512x32xbf16> - %extracted_slice_8 = tensor.extract_slice %arg29[%arg20, %arg24] [32, 32] [1, 1] : tensor<64x128xbf16> to tensor<32x32xbf16> - %13 = linalg.fill ins(%cst : bf16) outs(%extracted_slice_8 : tensor<32x32xbf16>) -> tensor<32x32xbf16> - %expanded = tensor.expand_shape %extracted_slice_5 [[0, 1], [2]] output_shape [1, 32, 512] : tensor<32x512xbf16> into tensor<1x32x512xbf16> - %expanded_9 = tensor.expand_shape %unpack [[0, 1], [2]] output_shape [1, 32, 512] : tensor<512x32xbf16> into tensor<1x512x32xbf16> - %14 = linalg.batch_reduce_matmul ins(%expanded, %expanded_9 : tensor<1x32x512xbf16>, tensor<1x512x32xbf16>) outs(%13 : tensor<32x32xbf16>) -> tensor<32x32xbf16> - %15 = affine.apply #map4(%arg12, %arg24, %arg4) - %16 = affine.apply #map5(%arg8, %arg20, %arg3) - %extracted_slice_10 = tensor.extract_slice %arg2[%15] [32] [1] : tensor<256xbf16> to tensor<32xbf16> - %extracted_slice_11 = tensor.extract_slice %0[%16, %15] [32, 32] [1, 1] : tensor<128x256xbf16> to tensor<32x32xbf16> - 
%broadcasted = linalg.broadcast ins(%extracted_slice_10 : tensor<32xbf16>) outs(%extracted_slice_11 : tensor<32x32xbf16>) dimensions = [0] - %extracted_slice_12 = tensor.extract_slice %arg30[%arg20, %arg24] [32, 32] [1, 1] : tensor<64x128xbf16> to tensor<32x32xbf16> - %17 = linalg.add ins(%14, %broadcasted : tensor<32x32xbf16>, tensor<32x32xbf16>) outs(%extracted_slice_12 : tensor<32x32xbf16>) -> tensor<32x32xbf16> - %inserted_slice_13 = tensor.insert_slice %14 into %arg29[%arg20, %arg24] [32, 32] [1, 1] : tensor<32x32xbf16> into tensor<64x128xbf16> - %extracted_slice_14 = tensor.extract_slice %arg31[%arg20, %arg24] [32, 32] [1, 1] : tensor<64x128xbf16> to tensor<32x32xbf16> - %18 = linalg.exp ins(%17 : tensor<32x32xbf16>) outs(%extracted_slice_14 : tensor<32x32xbf16>) -> tensor<32x32xbf16> - %inserted_slice_15 = tensor.insert_slice %17 into %arg30[%arg20, %arg24] [32, 32] [1, 1] : tensor<32x32xbf16> into tensor<64x128xbf16> - %inserted_slice_16 = tensor.insert_slice %18 into %arg31[%arg20, %arg24] [32, 32] [1, 1] : tensor<32x32xbf16> into tensor<64x128xbf16> - scf.yield %inserted_slice_13, %inserted_slice_15, %inserted_slice_16 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> + %5 = affine.apply #map2(%arg4) + %extracted_slice_0 = tensor.extract_slice %arg1[0, %5, 0, 0] [32, 4, 16, 32] [1, 1, 1, 1] : tensor<32x8x16x32xbf16> to tensor<32x4x16x32xbf16> + %extracted_slice_1 = tensor.extract_slice %1[0, %4] [512, 128] [1, 1] : tensor<512x256xbf16> to tensor<512x128xbf16> + %6 = affine.apply #map(%arg3) + %7 = affine.apply #map1(%arg4) + %extracted_slice_2 = tensor.extract_slice %arg5[%6, %7] [64, 128] [1, 1] : tensor<128x256xbf16> to tensor<64x128xbf16> + %8 = affine.apply #map(%arg3) + %9 = affine.apply #map1(%arg4) + %10 = affine.apply #map(%arg3) + %11 = affine.apply #map1(%arg4) + %12 = affine.apply #map1(%arg4) + %13 = affine.apply #map(%arg3) + %14 = affine.apply #map1(%arg4) + %extracted_slice_3 = tensor.extract_slice %arg2[%12] [128] [1] : tensor<256xbf16> to tensor<128xbf16> + %extracted_slice_4 = tensor.extract_slice %0[%13, %14] [64, 128] [1, 1] : tensor<128x256xbf16> to tensor<64x128xbf16> + %extracted_slice_5 = tensor.extract_slice %arg6[%10, %11] [64, 128] [1, 1] : tensor<128x256xbf16> to tensor<64x128xbf16> + %15 = affine.apply #map(%arg3) + %16 = affine.apply #map1(%arg4) + %17 = affine.apply #map(%arg3) + %18 = affine.apply #map1(%arg4) + %extracted_slice_6 = tensor.extract_slice %arg7[%17, %18] [64, 128] [1, 1] : tensor<128x256xbf16> to tensor<64x128xbf16> + %19:3 = scf.for %arg8 = %c0 to %c64 step %c64 iter_args(%arg9 = %extracted_slice_2, %arg10 = %extracted_slice_5, %arg11 = %extracted_slice_6) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { + %22:3 = scf.for %arg12 = %c0 to %c128 step %c128 iter_args(%arg13 = %arg9, %arg14 = %arg10, %arg15 = %arg11) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { + %23:3 = scf.for %arg16 = %c0 to %c512 step %c512 iter_args(%arg17 = %arg13, %arg18 = %arg14, %arg19 = %arg15) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { + %extracted_slice_7 = tensor.extract_slice %extracted_slice[%arg8, %arg16] [64, 512] [1, 1] : tensor<64x512xbf16> to tensor<64x512xbf16> + %24 = affine.apply #map3(%arg16) + %25 = affine.apply #map4(%arg12) + %extracted_slice_8 = tensor.extract_slice %extracted_slice_0[%24, %25, 0, 0] [32, 4, 16, 32] [1, 1, 1, 1] : tensor<32x4x16x32xbf16> to tensor<32x4x16x32xbf16> + %extracted_slice_9 = tensor.extract_slice %extracted_slice_1[%arg16, %arg12] 
[512, 128] [1, 1] : tensor<512x128xbf16> to tensor<512x128xbf16> + %extracted_slice_10 = tensor.extract_slice %arg17[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> to tensor<64x128xbf16> + %extracted_slice_11 = tensor.extract_slice %extracted_slice_3[%arg12] [128] [1] : tensor<128xbf16> to tensor<128xbf16> + %extracted_slice_12 = tensor.extract_slice %extracted_slice_4[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> to tensor<64x128xbf16> + %extracted_slice_13 = tensor.extract_slice %arg18[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> to tensor<64x128xbf16> + %extracted_slice_14 = tensor.extract_slice %arg19[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> to tensor<64x128xbf16> + %26:3 = scf.for %arg20 = %c0 to %c64 step %c32 iter_args(%arg21 = %extracted_slice_10, %arg22 = %extracted_slice_13, %arg23 = %extracted_slice_14) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { + %27:3 = scf.for %arg24 = %c0 to %c128 step %c32 iter_args(%arg25 = %arg21, %arg26 = %arg22, %arg27 = %arg23) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { + %28:3 = scf.for %arg28 = %c0 to %c512 step %c512 iter_args(%arg29 = %arg25, %arg30 = %arg26, %arg31 = %arg27) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { + %extracted_slice_17 = tensor.extract_slice %extracted_slice_7[%arg20, %arg28] [32, 512] [1, 1] : tensor<64x512xbf16> to tensor<32x512xbf16> + %29 = affine.apply #map3(%arg28) + %30 = affine.apply #map4(%arg24) + %extracted_slice_18 = tensor.extract_slice %extracted_slice_8[%29, %30, 0, 0] [32, 1, 16, 32] [1, 1, 1, 1] : tensor<32x4x16x32xbf16> to tensor<32x1x16x32xbf16> + %extracted_slice_19 = tensor.extract_slice %extracted_slice_9[%arg28, %arg24] [512, 32] [1, 1] : tensor<512x128xbf16> to tensor<512x32xbf16> + %unpack = tensor.unpack %extracted_slice_18 inner_dims_pos = [0, 1] inner_tiles = [16, 32] into %extracted_slice_19 : tensor<32x1x16x32xbf16> -> tensor<512x32xbf16> + %extracted_slice_20 = tensor.extract_slice %arg29[%arg20, %arg24] [32, 32] [1, 1] : tensor<64x128xbf16> to tensor<32x32xbf16> + %31 = linalg.fill ins(%cst : bf16) outs(%extracted_slice_20 : tensor<32x32xbf16>) -> tensor<32x32xbf16> + %expanded = tensor.expand_shape %extracted_slice_17 [[0, 1], [2]] output_shape [1, 32, 512] : tensor<32x512xbf16> into tensor<1x32x512xbf16> + %expanded_21 = tensor.expand_shape %unpack [[0, 1], [2]] output_shape [1, 32, 512] : tensor<512x32xbf16> into tensor<1x512x32xbf16> + %32 = linalg.batch_reduce_matmul ins(%expanded, %expanded_21 : tensor<1x32x512xbf16>, tensor<1x512x32xbf16>) outs(%31 : tensor<32x32xbf16>) -> tensor<32x32xbf16> + %extracted_slice_22 = tensor.extract_slice %extracted_slice_11[%arg24] [32] [1] : tensor<128xbf16> to tensor<32xbf16> + %extracted_slice_23 = tensor.extract_slice %extracted_slice_12[%arg20, %arg24] [32, 32] [1, 1] : tensor<64x128xbf16> to tensor<32x32xbf16> + %broadcasted = linalg.broadcast ins(%extracted_slice_22 : tensor<32xbf16>) outs(%extracted_slice_23 : tensor<32x32xbf16>) dimensions = [0] + %extracted_slice_24 = tensor.extract_slice %arg30[%arg20, %arg24] [32, 32] [1, 1] : tensor<64x128xbf16> to tensor<32x32xbf16> + %33 = linalg.add ins(%32, %broadcasted : tensor<32x32xbf16>, tensor<32x32xbf16>) outs(%extracted_slice_24 : tensor<32x32xbf16>) -> tensor<32x32xbf16> + %inserted_slice_25 = tensor.insert_slice %32 into %arg29[%arg20, %arg24] [32, 32] [1, 1] : tensor<32x32xbf16> into tensor<64x128xbf16> + %extracted_slice_26 = tensor.extract_slice %arg31[%arg20, %arg24] [32, 32] [1, 1] 
: tensor<64x128xbf16> to tensor<32x32xbf16> + %34 = linalg.exp ins(%33 : tensor<32x32xbf16>) outs(%extracted_slice_26 : tensor<32x32xbf16>) -> tensor<32x32xbf16> + %inserted_slice_27 = tensor.insert_slice %33 into %arg30[%arg20, %arg24] [32, 32] [1, 1] : tensor<32x32xbf16> into tensor<64x128xbf16> + %inserted_slice_28 = tensor.insert_slice %34 into %arg31[%arg20, %arg24] [32, 32] [1, 1] : tensor<32x32xbf16> into tensor<64x128xbf16> + scf.yield %inserted_slice_25, %inserted_slice_27, %inserted_slice_28 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> } - scf.yield %10#0, %10#1, %10#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> + scf.yield %28#0, %28#1, %28#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> } - scf.yield %9#0, %9#1, %9#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> + scf.yield %27#0, %27#1, %27#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> } - %inserted_slice = tensor.insert_slice %8#0 into %arg17[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<64x128xbf16> - %inserted_slice_3 = tensor.insert_slice %8#1 into %arg18[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<64x128xbf16> - %inserted_slice_4 = tensor.insert_slice %8#2 into %arg19[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<64x128xbf16> - scf.yield %inserted_slice, %inserted_slice_3, %inserted_slice_4 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> + %inserted_slice = tensor.insert_slice %26#0 into %arg17[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<64x128xbf16> + %inserted_slice_15 = tensor.insert_slice %26#1 into %arg18[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<64x128xbf16> + %inserted_slice_16 = tensor.insert_slice %26#2 into %arg19[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<64x128xbf16> + scf.yield %inserted_slice, %inserted_slice_15, %inserted_slice_16 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> } - scf.yield %7#0, %7#1, %7#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> + scf.yield %23#0, %23#1, %23#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> } - scf.yield %6#0, %6#1, %6#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> + scf.yield %22#0, %22#1, %22#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> } + %20 = affine.apply #map(%arg3) + %21 = affine.apply #map1(%arg4) scf.forall.in_parallel { - tensor.parallel_insert_slice %5#2 into %arg7[%3, %4] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<128x256xbf16> - tensor.parallel_insert_slice %5#1 into %arg6[%3, %4] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<128x256xbf16> - tensor.parallel_insert_slice %5#0 into %arg5[%3, %4] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<128x256xbf16> + tensor.parallel_insert_slice %19#2 into %arg7[%20, %21] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<128x256xbf16> + tensor.parallel_insert_slice %19#1 into %arg6[%15, %16] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<128x256xbf16> + tensor.parallel_insert_slice %19#0 into %arg5[%8, %9] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<128x256xbf16> } } return %2#2 : tensor<128x256xbf16> } + + + + func.func @main_entry(%arg0: tensor<128x128x64x64xbf16>, %arg1: tensor<128x128x32x64x2xbf16>) -> tensor<128x128x64x64xbf16> attributes {llvm.emit_c_interface} { + %c2 = arith.constant 2 : index + %c128 = arith.constant 128 : index + %c1 = arith.constant 
1 : index + %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c32 = arith.constant 32 : index + %c4 = arith.constant 4 : index + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : bf16 + %0 = tensor.empty() : tensor<128x128x64x64xbf16> + %1 = tensor.empty() : tensor<128x128x64x64xf32> + %2 = tensor.empty() : tensor<2x1x1x128x128x64x64xf32> + %3 = scf.forall (%arg2) in (2) shared_outs(%arg3 = %2) -> (tensor<2x1x1x128x128x64x64xf32>) { + %extracted_slice = tensor.extract_slice %arg3[%arg2, 0, 0, 0, 0, 0, 0] [1, 1, 1, 128, 128, 64, 64] [1, 1, 1, 1, 1, 1, 1] : tensor<2x1x1x128x128x64x64xf32> to tensor<128x128x64x64xf32> + %5 = scf.forall (%arg4) in (7) shared_outs(%arg5 = %extracted_slice) -> (tensor<128x128x64x64xf32>) { + %6 = affine.min affine_map<(d0) -> (d0 * -19 + 128, 19)>(%arg4) + %7 = affine.max affine_map<(d0) -> (0, d0)>(%6) + %8 = affine.apply affine_map<(d0) -> (d0 * 19)>(%arg4) + %extracted_slice_0 = tensor.extract_slice %arg5[%8, 0, 0, 0] [%7, 128, 64, 64] [1, 1, 1, 1] : tensor<128x128x64x64xf32> to tensor + %9 = scf.forall (%arg6) in (4) shared_outs(%arg7 = %extracted_slice_0) -> (tensor) { + %11 = affine.max affine_map<(d0) -> (0, d0)>(%6) + %12 = affine.apply affine_map<(d0) -> (d0 * 32)>(%arg6) + %extracted_slice_1 = tensor.extract_slice %arg7[0, %12, 0, 0] [%11, 32, 64, 64] [1, 1, 1, 1] : tensor to tensor + %13 = scf.for %arg8 = %c0 to %11 step %c4 iter_args(%arg9 = %extracted_slice_1) -> (tensor) { + %15 = affine.min affine_map<(d0)[s0] -> (-d0 + s0, 4)>(%arg8)[%11] + %extracted_slice_2 = tensor.extract_slice %arg9[%arg8, 0, 0, 0] [%15, 32, 64, 64] [1, 1, 1, 1] : tensor to tensor + %16 = scf.for %arg10 = %c0 to %c32 step %c4 iter_args(%arg11 = %extracted_slice_2) -> (tensor) { + %extracted_slice_3 = tensor.extract_slice %arg11[0, %arg10, 0, 0] [%15, 4, 64, 64] [1, 1, 1, 1] : tensor to tensor + %17 = scf.for %arg12 = %c0 to %c64 step %c16 iter_args(%arg13 = %extracted_slice_3) -> (tensor) { + %extracted_slice_5 = tensor.extract_slice %arg13[0, 0, 0, 0] [%15, 4, 64, 64] [1, 1, 1, 1] : tensor to tensor + %18 = scf.for %arg14 = %c0 to %15 step %c1 iter_args(%arg15 = %extracted_slice_5) -> (tensor) { + %extracted_slice_7 = tensor.extract_slice %arg15[%arg14, 0, 0, 0] [1, 4, 64, 64] [1, 1, 1, 1] : tensor to tensor<1x4x64x64xf32> + %19 = scf.for %arg16 = %c0 to %c4 step %c1 iter_args(%arg17 = %extracted_slice_7) -> (tensor<1x4x64x64xf32>) { + %20 = affine.apply affine_map<(d0)[s0, s1] -> (d0 * 19 + s0 + s1)>(%arg4)[%arg14, %arg8] + %21 = affine.apply affine_map<(d0)[s0] -> (d0 * 64 + s0)>(%arg2)[%arg12] + %extracted_slice_9 = tensor.extract_slice %arg0[%20, %21, 0, 0] [1, 16, 64, 64] [1, 1, 1, 1] : tensor<128x128x64x64xbf16> to tensor<16x64x64xbf16> + %22 = affine.apply affine_map<(d0)[s0, s1] -> (d0 * 32 + s0 + s1)>(%arg6)[%arg16, %arg10] + %23 = affine.apply affine_map<(d0)[s0] -> (d0 * 64 + s0)>(%arg2)[%arg12] + %extracted_slice_10 = tensor.extract_slice %arg1[%22, %23, 0, 0, 0] [1, 16, 32, 64, 2] [1, 1, 1, 1, 1] : tensor<128x128x32x64x2xbf16> to tensor<16x32x64x2xbf16> + %extracted_slice_11 = tensor.extract_slice %arg17[0, %arg16, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : tensor<1x4x64x64xf32> to tensor<64x64xf32> + %24 = arith.cmpi eq, %arg12, %c0 : index + %25 = scf.if %24 -> (tensor<64x64xf32>) { + %26 = linalg.fill ins(%cst : bf16) outs(%extracted_slice_11 : tensor<64x64xf32>) -> tensor<64x64xf32> + %27 = linalgx.batch_reduce_matmul_vnni ins(%extracted_slice_9, %extracted_slice_10 : tensor<16x64x64xbf16>, tensor<16x32x64x2xbf16>) outs(%26 
: tensor<64x64xf32>) -> tensor<64x64xf32> + scf.yield %27 : tensor<64x64xf32> + } else { + %26 = linalgx.batch_reduce_matmul_vnni ins(%extracted_slice_9, %extracted_slice_10 : tensor<16x64x64xbf16>, tensor<16x32x64x2xbf16>) outs(%extracted_slice_11 : tensor<64x64xf32>) -> tensor<64x64xf32> + scf.yield %26 : tensor<64x64xf32> + } + %inserted_slice_12 = tensor.insert_slice %25 into %arg17[0, %arg16, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : tensor<64x64xf32> into tensor<1x4x64x64xf32> + scf.yield %inserted_slice_12 : tensor<1x4x64x64xf32> + } + %inserted_slice_8 = tensor.insert_slice %19 into %arg15[%arg14, 0, 0, 0] [1, 4, 64, 64] [1, 1, 1, 1] : tensor<1x4x64x64xf32> into tensor + scf.yield %inserted_slice_8 : tensor + } + %inserted_slice_6 = tensor.insert_slice %18 into %arg13[0, 0, 0, 0] [%15, 4, 64, 64] [1, 1, 1, 1] : tensor into tensor + scf.yield %inserted_slice_6 : tensor + } + %inserted_slice_4 = tensor.insert_slice %17 into %arg11[0, %arg10, 0, 0] [%15, 4, 64, 64] [1, 1, 1, 1] : tensor into tensor + scf.yield %inserted_slice_4 : tensor + } + %inserted_slice = tensor.insert_slice %16 into %arg9[%arg8, 0, 0, 0] [%15, 32, 64, 64] [1, 1, 1, 1] : tensor into tensor + scf.yield %inserted_slice : tensor + } + %14 = affine.apply affine_map<(d0) -> (d0 * 32)>(%arg6) + scf.forall.in_parallel { + tensor.parallel_insert_slice %13 into %arg7[0, %14, 0, 0] [%11, 32, 64, 64] [1, 1, 1, 1] : tensor into tensor + } + } + %10 = affine.apply affine_map<(d0) -> (d0 * 19)>(%arg4) + scf.forall.in_parallel { + tensor.parallel_insert_slice %9 into %arg5[%10, 0, 0, 0] [%7, 128, 64, 64] [1, 1, 1, 1] : tensor into tensor<128x128x64x64xf32> + } + } + scf.forall.in_parallel { + tensor.parallel_insert_slice %5 into %arg3[%arg2, 0, 0, 0, 0, 0, 0] [1, 1, 1, 128, 128, 64, 64] [1, 1, 1, 1, 1, 1, 1] : tensor<128x128x64x64xf32> into tensor<2x1x1x128x128x64x64xf32> + } + } + %4 = scf.forall (%arg2) in (128) shared_outs(%arg3 = %0) -> (tensor<128x128x64x64xbf16>) { + %extracted_slice = tensor.extract_slice %1[%arg2, 0, 0, 0] [1, 128, 64, 64] [1, 1, 1, 1] : tensor<128x128x64x64xf32> to tensor<1x128x64x64xf32> + %extracted_slice_0 = tensor.extract_slice %arg3[%arg2, 0, 0, 0] [1, 128, 64, 64] [1, 1, 1, 1] : tensor<128x128x64x64xbf16> to tensor<1x128x64x64xbf16> + %5:2 = scf.for %arg4 = %c0 to %c128 step %c1 iter_args(%arg5 = %extracted_slice, %arg6 = %extracted_slice_0) -> (tensor<1x128x64x64xf32>, tensor<1x128x64x64xbf16>) { + %extracted_slice_1 = tensor.extract_slice %arg5[0, %arg4, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : tensor<1x128x64x64xf32> to tensor<1x1x64x64xf32> + %extracted_slice_2 = tensor.extract_slice %arg6[0, %arg4, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : tensor<1x128x64x64xbf16> to tensor<1x1x64x64xbf16> + %6:2 = scf.for %arg7 = %c0 to %c64 step %c1 iter_args(%arg8 = %extracted_slice_1, %arg9 = %extracted_slice_2) -> (tensor<1x1x64x64xf32>, tensor<1x1x64x64xbf16>) { + %extracted_slice_4 = tensor.extract_slice %arg8[0, 0, %arg7, 0] [1, 1, 1, 64] [1, 1, 1, 1] : tensor<1x1x64x64xf32> to tensor<1x1x1x64xf32> + %extracted_slice_5 = tensor.extract_slice %arg9[0, 0, %arg7, 0] [1, 1, 1, 64] [1, 1, 1, 1] : tensor<1x1x64x64xbf16> to tensor<1x1x1x64xbf16> + %7:2 = scf.for %arg10 = %c0 to %c64 step %c32 iter_args(%arg11 = %extracted_slice_4, %arg12 = %extracted_slice_5) -> (tensor<1x1x1x64xf32>, tensor<1x1x1x64xbf16>) { + %extracted_slice_8 = tensor.extract_slice %arg11[0, 0, 0, %arg10] [1, 1, 1, 32] [1, 1, 1, 1] : tensor<1x1x1x64xf32> to tensor<1x1x1x32xf32> + %8 = linalg.fill ins(%cst : bf16) outs(%extracted_slice_8 : 
tensor<1x1x1x32xf32>) -> tensor<1x1x1x32xf32> + %9 = scf.for %arg13 = %c0 to %c2 step %c1 iter_args(%arg14 = %8) -> (tensor<1x1x1x32xf32>) { + %extracted_slice_12 = tensor.extract_slice %3[%arg13, 0, 0, %arg2, %arg4, %arg7, %arg10] [1, 1, 1, 1, 1, 1, 32] [1, 1, 1, 1, 1, 1, 1] : tensor<2x1x1x128x128x64x64xf32> to tensor<1x1x1x1x1x1x32xf32> + %reduced = linalg.reduce ins(%extracted_slice_12 : tensor<1x1x1x1x1x1x32xf32>) outs(%arg14 : tensor<1x1x1x32xf32>) dimensions = [0, 1, 2] + (%in: f32, %init: f32) { + %11 = arith.addf %in, %init : f32 + linalg.yield %11 : f32 + } + scf.yield %reduced : tensor<1x1x1x32xf32> + } + %extracted_slice_9 = tensor.extract_slice %arg12[0, 0, 0, %arg10] [1, 1, 1, 32] [1, 1, 1, 1] : tensor<1x1x1x64xbf16> to tensor<1x1x1x32xbf16> + %10 = linalg.copy ins(%9 : tensor<1x1x1x32xf32>) outs(%extracted_slice_9 : tensor<1x1x1x32xbf16>) -> tensor<1x1x1x32xbf16> + %inserted_slice_10 = tensor.insert_slice %9 into %arg11[0, 0, 0, %arg10] [1, 1, 1, 32] [1, 1, 1, 1] : tensor<1x1x1x32xf32> into tensor<1x1x1x64xf32> + %inserted_slice_11 = tensor.insert_slice %10 into %arg12[0, 0, 0, %arg10] [1, 1, 1, 32] [1, 1, 1, 1] : tensor<1x1x1x32xbf16> into tensor<1x1x1x64xbf16> + scf.yield %inserted_slice_10, %inserted_slice_11 : tensor<1x1x1x64xf32>, tensor<1x1x1x64xbf16> + } + %inserted_slice_6 = tensor.insert_slice %7#0 into %arg8[0, 0, %arg7, 0] [1, 1, 1, 64] [1, 1, 1, 1] : tensor<1x1x1x64xf32> into tensor<1x1x64x64xf32> + %inserted_slice_7 = tensor.insert_slice %7#1 into %arg9[0, 0, %arg7, 0] [1, 1, 1, 64] [1, 1, 1, 1] : tensor<1x1x1x64xbf16> into tensor<1x1x64x64xbf16> + scf.yield %inserted_slice_6, %inserted_slice_7 : tensor<1x1x64x64xf32>, tensor<1x1x64x64xbf16> + } + %inserted_slice = tensor.insert_slice %6#0 into %arg5[0, %arg4, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : tensor<1x1x64x64xf32> into tensor<1x128x64x64xf32> + %inserted_slice_3 = tensor.insert_slice %6#1 into %arg6[0, %arg4, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : tensor<1x1x64x64xbf16> into tensor<1x128x64x64xbf16> + scf.yield %inserted_slice, %inserted_slice_3 : tensor<1x128x64x64xf32>, tensor<1x128x64x64xbf16> + } + scf.forall.in_parallel { + tensor.parallel_insert_slice %5#1 into %arg3[%arg2, 0, 0, 0] [1, 128, 64, 64] [1, 1, 1, 1] : tensor<1x128x64x64xbf16> into tensor<128x128x64x64xbf16> + } + } + return %4 : tensor<128x128x64x64xbf16> + } + From 4c2b3b8bfdcc2ec44f2afd3a13f271d8bfa63691 Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Wed, 31 Jul 2024 09:19:24 +0800 Subject: [PATCH 20/66] fix bugs --- include/gc/Transforms/TilingVector.h | 21 +- lib/gc/Transforms/CPUPhysicalResigterPass.cpp | 491 ++++++++++++------ lib/gc/Transforms/LowerTileVectorPass.cpp | 35 ++ lib/gc/Transforms/Pipeline.cpp | 6 +- .../gc/transforms/cpu-vetor-distribution.mlir | 227 ++------ 5 files changed, 454 insertions(+), 326 deletions(-) diff --git a/include/gc/Transforms/TilingVector.h b/include/gc/Transforms/TilingVector.h index b65a91259..b527a73dd 100644 --- a/include/gc/Transforms/TilingVector.h +++ b/include/gc/Transforms/TilingVector.h @@ -18,6 +18,7 @@ #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/Transforms/Transforms.h" +#include "mlir/Dialect/Traits.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" #include "mlir/Dialect/Vector/Transforms/Passes.h" @@ -43,6 +44,7 @@ #include #include #include +#include #include #include #include @@ -54,6 +56,7 @@ Value makeIndexArithConstantOp(OpBuilder &opBuilder, Location &loc, 
int64_t x); void setOperationCorrectOperand( Operation *op, const ValueRange &iterArgs, const llvm::DenseMap &operandIdxMap, + DenseMap &originalOperandLoopArgsMap, ArrayRef inductionVars, const llvm::DenseMap &opPermuationMap); mlir::FailureOr getOperationOperateTensor(Operation *op); @@ -441,11 +444,18 @@ class ForLoopGenerator : virtual public CanonicalizerCommonUsedData { const ValueRange &iterArgs, VectorType type, const llvm::ArrayRef &dims, llvm::SmallVector &inductionVars, - const llvm::DenseMap &operandIdxMap); + llvm::DenseMap &operandIdxMap, + DenseMap &originalOperandMap, + DenseMap &operandOriginalMap, + llvm::SmallVector &nextAnchorResults, + DenseMap &nextAnchorResultsIdxMap, + DenseMap &forResultOrignalResultMap); void moveOperationsToCurrentForBody( const size_t groupIdx, const OpBuilder &b, ArrayRef inductionVars, const llvm::DenseMap &operandIdxMap, - const ValueRange &loopState, std::queue &queue); + const ValueRange &loopState, + DenseMap &originalOperandLoopArgsMap, + std::queue &queue); void getResultInCurrentOps(const size_t anchorIdx, const size_t groupId, const std::queue ops, @@ -479,6 +489,8 @@ class ForLoopGenerator : virtual public CanonicalizerCommonUsedData { std::queue &movedOperation, ArrayRef inductionVars, const llvm::DenseMap &operandIdxMap, const ValueRange &loopState, + DenseMap &originalOperandLoopArgsMap, + DenseMap &loopArgsOriginalOperandMap, const llvm::SmallVector &nextAnchorResults, DenseMap &forResultOrignalResultMap); @@ -535,7 +547,8 @@ class ForLoopGenerator : virtual public CanonicalizerCommonUsedData { scf::ForOp generateTransposeForLoopWithLastDim( OpBuilder &opBuilder, const size_t grpIdx, const size_t forDimIdx, const int tpSteps, const Location &loc, SmallVector &inductionVars, - const ValueRange &iterArgs); + const ValueRange &iterArgs, DenseMap &operandIdxMap, + DenseMap &originalOperandMap); scf::ForOp generateTransposeScalarDataMovement( OpBuilder &opBuilder, const size_t grpIdx, const size_t forDimIdx, @@ -566,7 +579,7 @@ class VectorOperationAnalyzer : virtual public CanonicalizerCommonUsedData { void analysisGroupMaxSteps(); void analysisGroupOperaion(); void analysisGroupOperationResults(); - void specialOperationAnchorRectify(); + void specialOperationRectify(DenseMap &visitedOperation); }; /// Vectorize vector operation with target machines simd instructions. 
class CanonicalizerVectorOperation : virtual public ForLoopGenerator, diff --git a/lib/gc/Transforms/CPUPhysicalResigterPass.cpp b/lib/gc/Transforms/CPUPhysicalResigterPass.cpp index cd262a299..db14a1999 100644 --- a/lib/gc/Transforms/CPUPhysicalResigterPass.cpp +++ b/lib/gc/Transforms/CPUPhysicalResigterPass.cpp @@ -31,15 +31,16 @@ void printGroupOps(SmallVector, 8> &opGroups) { if (grp.empty()) { continue; } - std::cout << "__________________ group start_____________" << std::endl; + llvm::outs() << "__________________ group start_____________" + << "\n"; std::queue tmpQ(grp); while (!tmpQ.empty()) { auto cur = tmpQ.front(); tmpQ.pop(); cur->dump(); } - std::cout << "__________________ group end_____________" << std::endl; - std::cout << std::endl; + llvm::outs() << "__________________ group end_____________" + << "\n"; } } @@ -57,6 +58,25 @@ void printQueue(const std::queue &opQueue) { "__________________" << std::endl; } + +/// whether op2 use op1 result +/// Currently we just enable this function for write and read operation +template || + std::is_same_v, + T>> +static bool isOperationsHasDefUseRelation(Operation *op1, Operation *op2) { + if (!isa(op1) and !isa(op2)) { + return false; + } + for (Value opd : op2->getOperands()) { + if (opd.getDefiningOp() == op1) { + return true; + } + } + + return false; +} /// Get the index position of the first element that is true static size_t getFirstTrueIndex(ArrayRef ararys) { for (size_t i = 0; i < ararys.size(); i++) { @@ -656,20 +676,25 @@ Type getScalarType(Operation *op) { } Operation *createTensorEmptyBefore(Operation *op) { - auto rtType = mlir::dyn_cast(op->getResultTypes()[0]); + + auto rtType = dyn_cast(op->getResultTypes()[0]); IRRewriter reWriter(op); + Block *block = op->getBlock(); + + reWriter.setInsertionPoint(block, block->getOperations().begin()); SmallVector shapes; SmallVector dynDims; for (unsigned i = 0; i < rtType.getRank(); i++) { shapes.push_back(rtType.getDimSize(i)); if (rtType.isDynamicDim(i)) { - dynDims.push_back(reWriter.create(reWriter.getUnknownLoc(), - op->getResult(0), i)); + dynDims.push_back( + reWriter.create(op->getLoc(), op->getResult(0), i)); } } - return reWriter.create(op->getLoc(), rtType.getShape(), - rtType.getElementType(), dynDims); + auto emtpyOp = reWriter.create( + op->getLoc(), rtType.getShape(), rtType.getElementType(), dynDims); + return emtpyOp; } /// get the tensor that operation should write into @@ -966,6 +991,7 @@ Value makeIndexArithConstantOp(OpBuilder &opBuilder, const Location &loc, void ForLoopGenerator::moveOperationsToCurrentForBody( const size_t groupIdx, const OpBuilder &b, ArrayRef inductionVars, const DenseMap &operandIdxMap, const ValueRange &loopState, + DenseMap &originalOperandLoopArgsMap, std::queue &opQueue) { auto &opPermuationMap = getOpPermuationMap(); auto tmpQ(opQueue); @@ -974,7 +1000,8 @@ void ForLoopGenerator::moveOperationsToCurrentForBody( tmpQ.pop(); x->moveBefore(b.getBlock(), b.getBlock()->end()); // check operation type to set correct operand - setOperationCorrectOperand(x, loopState, operandIdxMap, inductionVars, + setOperationCorrectOperand(x, loopState, operandIdxMap, + originalOperandLoopArgsMap, inductionVars, opPermuationMap); } } @@ -1019,20 +1046,20 @@ void ForLoopGenerator::getResultInCurrentOps( /// update loop args related status /// \param nextAnchorArgsIdxMap anchor args index map -/// \param nextOperandArgsMap original value to next loop args map -/// \param nextArgsOperandMap next loop args to original value map +/// \param 
nextOriginalOperandMap original value to next loop args map +/// \param nextOperandOriginalMap next loop args to original value map void updateCurrentArgsStatus(const ValueRange &loopState, const size_t loopStateIdx, SmallVector &nextAnchorArgs, Value originalValue, DenseMap &nextAnchorArgsIdxMap, - DenseMap &nextOperandArgsMap, - DenseMap &nextArgsOperandMap) { + DenseMap &nextOriginalOperandMap, + DenseMap &nextOperandOriginalMap) { Value currentArgs = loopState[loopStateIdx]; nextAnchorArgs.emplace_back(currentArgs); nextAnchorArgsIdxMap[currentArgs] = nextAnchorArgs.size() - 1; - nextOperandArgsMap[originalValue] = currentArgs; - nextArgsOperandMap[currentArgs] = originalValue; + nextOriginalOperandMap[originalValue] = currentArgs; + nextOperandOriginalMap[currentArgs] = originalValue; } void ForLoopGenerator::getInitArgsToNextAnchor( @@ -1050,7 +1077,7 @@ void ForLoopGenerator::getInitArgsToNextAnchor( DenseSet visited; // find the next anchor arguments std::queue tmpQ(nextOperations); - DenseMap nextOperandArgsMap, nextArgsOperandMap; + DenseMap nextOriginalArgsMap, nextOperandOriginalMap; while (!tmpQ.empty()) { Operation *cur = tmpQ.front(); @@ -1062,14 +1089,14 @@ void ForLoopGenerator::getInitArgsToNextAnchor( int loopStateIdx = currentLoopStateIdxMap[originalOperandLoopArgsMap[x]]; updateCurrentArgsStatus(loopState, loopStateIdx, nextAnchorArgs, x, - nextAnchorArgsIdxMap, nextOperandArgsMap, - nextArgsOperandMap); + nextAnchorArgsIdxMap, nextOriginalArgsMap, + nextOperandOriginalMap); visited.insert(x); } } } - originalOperandLoopArgsMap = nextOperandArgsMap; - loopArgsOriginalOperandMap = nextArgsOperandMap; + originalOperandLoopArgsMap = nextOriginalArgsMap; + loopArgsOriginalOperandMap = nextOperandOriginalMap; } void ForLoopGenerator::getOperationInCurrentAnchor( @@ -1112,9 +1139,11 @@ void ForLoopGenerator::replaceOperationsWithForLoopResult( } } -/// \param [out] nextLoopStateidxMap +/// \param [out] nextAnchorArgsIdxMap /// \param [out] nextAnchorArgs /// \param [out] movingQueue +/// \param [in, out] originalOperandLoopArgsMap +/// \param [in, out] LoopArgsoriginalOperandMap void ForLoopGenerator::movePreOpToCurrentAnchor( const size_t anchorIdx, const size_t groupIdx, OpBuilder &b, ArrayRef inductionVars, const ValueRange &loopState, @@ -1130,19 +1159,19 @@ void ForLoopGenerator::movePreOpToCurrentAnchor( std::queue movingOperation; getOperationInCurrentAnchor(anchorIdx, candidateQueue, movingOperation); - // 2. get next anchor args - getInitArgsToNextAnchor(anchorIdx, groupIdx, candidateQueue, loopState, - currentLoopStateIdxMap, nextAnchorArgsIdxMap, - nextAnchorArgs, originalOperandLoopArgsMap, - LoopArgsoriginalOperandMap); - - // 3. rewrite operation as vectorize IR + // 2. rewrite operation as vectorize IR rewriteOperationAsVectorize(b, groupIdx, &movingOperation); - // 4. move opeartions to current for block + // 3. move opeartions to current for block moveOperationsToCurrentForBody(groupIdx, b, inductionVars, currentLoopStateIdxMap, loopState, - movingOperation); + originalOperandLoopArgsMap, movingOperation); + + // 4. get next anchor args + getInitArgsToNextAnchor(anchorIdx, groupIdx, candidateQueue, loopState, + currentLoopStateIdxMap, nextAnchorArgsIdxMap, + nextAnchorArgs, originalOperandLoopArgsMap, + LoopArgsoriginalOperandMap); // 5. 
move operations to moved queue while (!movingOperation.empty()) { @@ -1156,7 +1185,10 @@ void ForLoopGenerator::movePostOpToCurrentAnchor( const ValueRange &forResults, const Block *forBlock, std::queue &candidateOps, std::queue &movedOps, ArrayRef inductionVars, const DenseMap &operandIdxMap, - const ValueRange &loopState, const SmallVector &nextAnchorResults, + const ValueRange &loopState, + DenseMap &originalOperandLoopArgsMap, + DenseMap &loopArgsOriginalOperandMap, + const SmallVector &nextAnchorResults, DenseMap &forResultOrignalResultMap) { // 1. move post-op to current loop body @@ -1166,7 +1198,8 @@ void ForLoopGenerator::movePostOpToCurrentAnchor( rewriteOperationAsVectorize(b, groupIdx, &movingOperations); moveOperationsToCurrentForBody(anchorIdx, b, inductionVars, operandIdxMap, - loopState, movingOperations); + loopState, originalOperandLoopArgsMap, + movingOperations); // 2. replace correct for loop result to post-op IRRewriter rewriter(b); @@ -1224,7 +1257,7 @@ scf::ForOp ForLoopGenerator::reductionAxisGenerateForLoop( DenseMap &originalOperandLoopArgsMap, DenseMap &loopArgsOriginalOperandMap, llvm::SmallVector &nextAnchorResults, - llvm::DenseMap &nextAnchorResultsIdxMap, + DenseMap &nextAnchorResultsIdxMap, llvm::SmallVector &inductionVars, DenseMap &forResultOrignalResultMap, DenseMap &originalResultForResultMap) { @@ -1264,6 +1297,10 @@ scf::ForOp ForLoopGenerator::reductionAxisGenerateForLoop( DenseMap nextAnchorArgsIdxMap; SmallVector nextAnchorArgs; std::queue movedOperation; + DenseMap currentoriginalArgsMap = + originalOperandLoopArgsMap; + DenseMap currentArgsOriginalMap = + loopArgsOriginalOperandMap; DenseMap originalArgsMap, argsOriginalMap; movePreOpToCurrentAnchor( anchorIdx, groupIdx, b, inductionVars, loopState, @@ -1291,7 +1328,8 @@ scf::ForOp ForLoopGenerator::reductionAxisGenerateForLoop( movePostOpToCurrentAnchor( b, anchorIdx, groupIdx, nxtFor->getResults(), b.getBlock(), opQueue, movedOperation, inductionVars, currentLoopStateIdxMap, - loopState, nextAnchorResults, forResultOrignalResultMap); + loopState, currentoriginalArgsMap, currentArgsOriginalMap, + nextAnchorResults, forResultOrignalResultMap); // 4. generate loop results generateLoopResults(b, loc, anchorIdx, groupIdx, nextAnchorResults, @@ -1339,9 +1377,9 @@ scf::ForOp ForLoopGenerator::reductionAxisGenerateForLoop( rewriteOperationAsVectorize(b, groupIdx, &movingOperation); - moveOperationsToCurrentForBody(groupIdx, b, inductionVars, - currentLoopStateIdxMap, loopState, - movingOperation); + moveOperationsToCurrentForBody( + groupIdx, b, inductionVars, currentLoopStateIdxMap, loopState, + originalOperandLoopArgsMap, movingOperation); if (!rdCanonicalizer.getIsEmptyReduction()) { int accValIdx = currentLoopStateIdxMap [originalOperandLoopArgsMap[multireductionOp.getAcc()]]; @@ -1353,7 +1391,9 @@ scf::ForOp ForLoopGenerator::reductionAxisGenerateForLoop( movePostOpToCurrentAnchor( b, anchorIdx, groupIdx, ValueRange(), b.getBlock(), opQueue, movingOperation, inductionVars, currentLoopStateIdxMap, - loopState, nextAnchorResults, forResultOrignalResultMap); + loopState, originalOperandLoopArgsMap, + loopArgsOriginalOperandMap, nextAnchorResults, + forResultOrignalResultMap); nextAnchorResults.clear(); nextAnchorResults.emplace_back(reductionResult); @@ -1384,8 +1424,8 @@ scf::ForOp ForLoopGenerator::reductionAxisGenerateForLoop( return forOp; } -// Generate for loop for parallel axis of `vector.multi_reduction`. 
-// This function also call reduction axis for loop +/// Generate for loop for parallel axis of `vector.multi_reduction`. +/// This function also call reduction axis for loop scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( OpBuilder &opBuilder, const int groupIdx, const size_t parallelIdx, DenseMap ¤tLoopStateIdxMap, const ValueRange &initArgs, @@ -1442,6 +1482,10 @@ scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( DenseMap nextAnchorArgsIdxMap; SmallVector nextAnchorArgs; std::queue movedQueue; + DenseMap currentOriginalOperandMap = + originalOperandLoopArgsMap; + DenseMap currentOperandOriginalMap = + loopArgsOriginalOperandMap; movePreOpToCurrentAnchor( parallelIdx, groupIdx, b, inductionVars, loopState, currentLoopStateIdxMap, nextAnchorArgsIdxMap, nextAnchorArgs, @@ -1511,7 +1555,8 @@ scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( movePostOpToCurrentAnchor( b, parallelIdx, groupIdx, nxtFor->getResults(), nxtFor->getBlock(), opQueue, movedQueue, inductionVars, - currentLoopStateIdxMap, loopState, nextAnchorResults, + currentLoopStateIdxMap, loopState, currentOriginalOperandMap, + currentOperandOriginalMap, nextAnchorResults, forResultOrignalResultMap); // 4. generate loop results @@ -1603,7 +1648,8 @@ scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( scf::ForOp ForLoopGenerator::generateTransposeForLoopWithLastDim( OpBuilder &opBuilder, const size_t grpIdx, const size_t forDimIdx, const int tpSteps, const Location &loc, SmallVector &inductionVars, - const ValueRange &iterArgs) { + const ValueRange &iterArgs, DenseMap &operandIdxMap, + DenseMap &originalOperandMap) { auto &tpCanonicalizer = getTransposeCanonicalizers()[grpIdx]; vector::TransposeOp &tpOp = tpCanonicalizer.getCandidateOps()[0]; VectorType vtType = tpOp.getVector().getType(); @@ -1661,13 +1707,14 @@ scf::ForOp ForLoopGenerator::generateTransposeForLoopWithLastDim( writeVars[tpCanonicalizer.getFirstTpIdx()] = inductionVars[tpCanonicalizer.getSecondTpIdx()]; auto writeOp = b.create( - loc, transposeOp->getResults()[0], - successorWriteOp->getOperands()[1], writeVars, inBoundsVal); + loc, transposeOp->getResults()[0], loopState[0], writeVars, + inBoundsVal); maybeYieldValue(b, loc, writeOp->getResults()); } else { // outter loop auto nxtFor = generateTransposeForLoopWithLastDim( - b, grpIdx, forDimIdx + 1, tpSteps, loc, inductionVars, loopState); + b, grpIdx, forDimIdx + 1, tpSteps, loc, inductionVars, loopState, + operandIdxMap, originalOperandMap); maybeYieldValue(b, loc, nxtFor->getResults()); } }); @@ -1800,8 +1847,8 @@ scf::ForOp ForLoopGenerator::generateTransposeScalarDataMovement( } auto writeOp = b.create( - loc, transferReadOp->getResults()[0], - successorWriteOp->getOperands()[1], writeVars, inBoundsVal); + loc, transferReadOp->getResults()[0], loopState[0], writeVars, + inBoundsVal); maybeYieldValue(b, loc, writeOp->getResults()); } else { // outter loop @@ -1935,8 +1982,7 @@ scf::ForOp ForLoopGenerator::generateShapeCastReadWriteLoop( /*inBounds=*/inBoundsVal); auto writeOp = b.create( - loc, transferReadOp->getResults()[0], - successorWriteOp->getOperands()[1], + loc, transferReadOp->getResults()[0], loopState[0], loopType == sourceType ? 
smallRankShapeVars : inductionVars, inBoundsVal); maybeYieldValue(b, loc, writeOp->getResults()); @@ -2026,7 +2072,6 @@ scf::ForOp ForLoopGenerator::generateTransposeForLoop(const size_t grpIdx) { isTwoDTranspose = false; } OpBuilder b(tpOp); - SmallVector iterArgs; vector::TransferWriteOp successorWriteOp; for (Operation *x : tpOp->getUsers()) { if (isa(x)) { @@ -2034,13 +2079,24 @@ scf::ForOp ForLoopGenerator::generateTransposeForLoop(const size_t grpIdx) { break; } } - iterArgs.emplace_back(successorWriteOp->getOperands()[1]); + // iterArgs.emplace_back(successorWriteOp->getOperands()[1]); + SmallVector operands; + DenseMap operandIdxMap; + DenseMap originalOperandMap; + auto &initArgs = getGroupOpInitArgs()[grpIdx]; + for (Value x : initArgs) { + operands.emplace_back(x); + operandIdxMap[x] = operands.size() - 1; + originalOperandMap[x] = x; + } + SmallVector iterArgs(operands.begin(), operands.end()); SmallVector inductionVars; IRRewriter rewriter(func); if (permuteSet.contains(rank - 1) and isTwoDTranspose) { scf::ForOp forOp = generateTransposeForLoopWithLastDim( - b, grpIdx, 0, tpStep, tpOp.getLoc(), inductionVars, iterArgs); + b, grpIdx, 0, tpStep, tpOp.getLoc(), inductionVars, iterArgs, + operandIdxMap, originalOperandMap); rewriter.replaceOp(successorWriteOp, forOp); // clear current group operation @@ -2332,13 +2388,15 @@ void CanonicalizerVectorOperation::run() { // Query groupResultYeildSet to map operaion result value to scf.yield // result value. - printGroupOps(getFusionStrategy().getOpGroups()); - std::cout << "___________ before analysis ________________" - << "\n"; + // printGroupOps(getFusionStrategy().getOpGroups()); + // std::cout << "___________ before analysis ________________" + // << "\n"; analysisGroupOperaion(); - std::cout << "___________ after analysis ________________" - << "\n"; - printGroupOps(getFusionStrategy().getOpGroups()); + // std::cout << "___________ after analysis ________________" + // << "\n"; + // printGroupOps(getFusionStrategy().getOpGroups()); + + func->dump(); // Speical Operation Canonicalization canonicalizeSpecialOperation(); @@ -2388,29 +2446,48 @@ void CanonicalizerVectorOperation::run() { return true; } -// +/// void setOperationCorrectOperand( Operation *op, const ValueRange &iterArgs, - const DenseMap &operandIdxMap, ArrayRef inductionVars, + const DenseMap &operandIdxMap, + DenseMap &originalOperandLoopArgsMap, + ArrayRef inductionVars, const DenseMap &opPermuationMap) { for (auto [idx, opd] : llvm::enumerate(op->getOperands())) { - if (operandIdxMap.contains(opd)) { - op->setOperand(idx, iterArgs[operandIdxMap.at(opd)]); + if (!originalOperandLoopArgsMap.contains(opd)) { + continue; } + Value loopArg = originalOperandLoopArgsMap[opd]; + if (!operandIdxMap.contains(loopArg)) { + continue; + } + op->setOperand(idx, iterArgs[operandIdxMap.at(loopArg)]); } int offset = isa(op) ? 
2 : 1; if (dyn_cast(op) || dyn_cast(op)) { - assert(opPermuationMap.contains(op)); auto permutationMap = opPermuationMap.at(op); auto dimExpr = permutationMap.getResults(); for (auto [idx, x] : llvm::enumerate(dimExpr)) { - if (mlir::dyn_cast(x)) { - auto dim = mlir::dyn_cast(x).getPosition(); - op->setOperand(dim + offset, inductionVars[dim]); + + if (!isa(x)) { + llvm::llvm_unreachable_internal( + "Permuatation map must contains dim expr."); + } + + auto dim = dyn_cast(x).getPosition(); + ShapedType tensorType = + cast(op->getOperandTypes()[offset - 1]); + if (tensorType.getRank() > (int64_t)inductionVars.size()) { + int64_t tensorOffset = tensorType.getRank() - inductionVars.size(); + + op->setOperand(dim + offset, inductionVars[dim - tensorOffset]); + continue; } + + op->setOperand(dim + offset, inductionVars[dim]); } } } @@ -2419,7 +2496,12 @@ scf::ForOp ForLoopGenerator::constructNestedForOp( const size_t forDimIdx, const size_t groupIdx, OpBuilder &b, const Location &loc, const ValueRange &iterArgs, VectorType type, const ArrayRef &dims, SmallVector &inductionVars, - const DenseMap &operandIdxMap) { + DenseMap &operandIdxMap, + DenseMap &originalOperandMap, + DenseMap &operandOriginalMap, + SmallVector &nextAnchorResults, + DenseMap &nextAnchorResultsIdxMap, + DenseMap &forResultOrignalResultMap) { const int loop_step = getDataTypeValidSteps(type); // loop initialization variable auto zero = makeIndexArithConstantOp(b, loc, 0); @@ -2435,22 +2517,56 @@ scf::ForOp ForLoopGenerator::constructNestedForOp( // inner most body of the loop if (forDimIdx == dims.size() - 1) { - moveOperationsToCurrentForBody( - groupIdx, b, inductionVars, operandIdxMap, loopState, - getFusionStrategy().getOpGroups()[groupIdx]); - llvm::MapVector> &resultSet = - getGroupOpResults()[groupIdx]; - SmallVector results(resultSet.size()); - size_t idx = 0; - for (auto itr = resultSet.begin(); itr != resultSet.end(); itr++) { - results[idx++] = itr->first; - } - maybeYieldValue(b, loc, results); + std::queue &opQueue = + getFusionStrategy().getOpGroups()[groupIdx]; + // 1. get operations in current anchor position + std::queue movingOperation; + getOperationInCurrentAnchor(forDimIdx, opQueue, movingOperation); + + // 2. rewrite operation as vectorize IR + rewriteOperationAsVectorize(b, groupIdx, &movingOperation); + + // 3. move opeartions to current for block + moveOperationsToCurrentForBody(groupIdx, b, inductionVars, + operandIdxMap, loopState, + originalOperandMap, movingOperation); + + getResultInCurrentOps(forDimIdx, groupIdx, movingOperation, + nextAnchorResults, forResultOrignalResultMap); + maybeYieldValue(b, loc, nextAnchorResults); } else { // outter loop - auto nxtFor = - constructNestedForOp(forDimIdx + 1, groupIdx, b, loc, loopState, - type, dims, inductionVars, operandIdxMap); + + // 1. 
move pre-Op to current body + DenseMap nextAnchorArgsIdxMap; + SmallVector nextAnchorArgs; + DenseMap currentOriginalOperandMap = originalOperandMap; + DenseMap currentOperandOriginalMap = operandOriginalMap; + + std::queue movedQueue; + std::queue &opQueue = + getFusionStrategy().getOpGroups()[groupIdx]; + movePreOpToCurrentAnchor( + forDimIdx, groupIdx, b, inductionVars, loopState, operandIdxMap, + nextAnchorArgsIdxMap, nextAnchorArgs, opQueue, movedQueue, + originalOperandMap, operandOriginalMap); + + auto nxtFor = constructNestedForOp( + forDimIdx + 1, groupIdx, b, loc, loopState, type, dims, + inductionVars, nextAnchorArgsIdxMap, originalOperandMap, + operandOriginalMap, nextAnchorResults, nextAnchorResultsIdxMap, + forResultOrignalResultMap); + + movePostOpToCurrentAnchor( + b, forDimIdx, groupIdx, nxtFor->getResults(), b.getBlock(), + opQueue, movedQueue, inductionVars, operandIdxMap, loopState, + currentOriginalOperandMap, currentOperandOriginalMap, + nextAnchorResults, forResultOrignalResultMap); + + generateLoopResults(b, loc, forDimIdx, groupIdx, nextAnchorResults, + nextAnchorResultsIdxMap, nxtFor->getResults(), + movedQueue, forResultOrignalResultMap); + maybeYieldValue(b, loc, nxtFor->getResults()); } }); @@ -2669,10 +2785,7 @@ bool hasDataDependency(Operation *op1, Operation *op2) { if (!isSpecialOp(op1) and !isSpecialOp(op2)) { return false; } - // only special operation may cause data dependency - if (!isSpecialOp(op1)) { - return hasDataDependency(op2, op1); - } + // TODO: Remove this condition to support special operation fusion in the // future if (disableSpecialOp) { @@ -2680,19 +2793,22 @@ bool hasDataDependency(Operation *op1, Operation *op2) { } // if op1 is read the value and pass it to op2, it is not data dependency - if (isa(op2)) { - for (Value opd : op1->getOperands()) { - if (opd.getDefiningOp() == op2) { - return false; - } - } + if (isOperationsHasDefUseRelation(op1, op2)) { + return false; } // if op2 is write the result from op2, it is not data dependency - if (isa(op2)) { - Value opd = op2->getOperand(0); - if (opd.getDefiningOp() == op1) { - return false; - } + if (isOperationsHasDefUseRelation(op1, op2)) { + return false; + } + + // broadcast only fuse with post-op + if (isa(op2)) { + return true; + } + + // only special operation may cause data dependency + if (!isSpecialOp(op1)) { + return hasDataDependency(op2, op1); } auto hasSameAxis = [](const SmallVector &dims1, @@ -2755,22 +2871,15 @@ bool hasDataDependency(Operation *op1, Operation *op2) { return true; } } - return false; } return true; }) .Case([&](vector::BroadcastOp broadcastOp) { - SmallVector dims1, dims2; - getOperationDataAxis(op1, dims1); - getOperationDataAxis(op2, dims2); - return true; - if (!isSpecialOp(op2)) { - return hasSameAxis(dims1, dims2); - } else { - } - return true; + return !OpTrait::util::staticallyKnownBroadcastable( + getOperationVectorType(op1, false)->getShape(), + getOperationVectorType(op2)->getShape()); }) .Case([&](vector::TransposeOp transposeOp) { SmallVector dims1, dims2; @@ -2992,7 +3101,7 @@ void ForLoopGenerator::rewriteOperationAsVectorize( setOpVectorizationPermutationMap( transferWriteOp, rewriter, - mlir::dyn_cast( + dyn_cast( transferWriteOp->getResult(0).getType()), transferWriteOp.getPermutationMap()); } @@ -3063,25 +3172,31 @@ void CanonicalizerCommonUsedData::updateOpOperandResultInGroups( while (!tmpOpQueue.empty()) { auto curOp = tmpOpQueue.front(); tmpOpQueue.pop(); - if (curOp == op) { - if 
(!failed(getOperationVectorType(init.getDefiningOp()))) { - newOpQueue.push(init.getDefiningOp()); - getFusionStrategy().getOpGroupIndexMap()[init.getDefiningOp()] = opGid; - getFusionStrategy().getOpAnchorPos()[init.getDefiningOp()] = - getFusionStrategy().getOpAnchorPos()[op]; - } - - newOpQueue.push(op); - if (result && !failed(getOperationVectorType(result.getDefiningOp()))) { - newOpQueue.push(result.getDefiningOp()); - getFusionStrategy().getOpGroupIndexMap()[result.getDefiningOp()] = - opGid; - getFusionStrategy().getOpAnchorPos()[result.getDefiningOp()] = - getFusionStrategy().getOpGroupIndexMap()[op]; - } - } else { + if (curOp != op) { newOpQueue.push(curOp); + continue; + } + + if (!failed(getOperationVectorType(init.getDefiningOp()))) { + newOpQueue.push(init.getDefiningOp()); + getFusionStrategy().getOpGroupIndexMap()[init.getDefiningOp()] = opGid; + getFusionStrategy().getOpAnchorPos()[init.getDefiningOp()] = + getFusionStrategy().getOpAnchorPos()[op]; + } + // directly use the read operation to do the fusion + if (isa(op) and !tmpOpQueue.empty()) { + IRRewriter rewrite(op); + rewrite.replaceOp(op, op->getOperand(0).getDefiningOp()); + continue; + } + newOpQueue.push(op); + + if (result && !failed(getOperationVectorType(result.getDefiningOp()))) { + newOpQueue.push(result.getDefiningOp()); + getFusionStrategy().getOpGroupIndexMap()[result.getDefiningOp()] = opGid; + getFusionStrategy().getOpAnchorPos()[result.getDefiningOp()] = + getFusionStrategy().getOpGroupIndexMap()[op]; } } getFusionStrategy().getOpGroups()[opGid] = newOpQueue; @@ -3160,19 +3275,54 @@ void VectorOperationAnalyzer::analysisGroupMaxSteps() { } } -void VectorOperationAnalyzer::specialOperationAnchorRectify() { +void VectorOperationAnalyzer::specialOperationRectify( + DenseMap &visitedOperation) { auto &opGroups = getFusionStrategy().getOpGroups(); + IRRewriter rewriter(func); + for (auto [idx, grp] : llvm::enumerate(opGroups)) { std::queue tmpQueue(grp); + std::queue newQueue; while (!tmpQueue.empty()) { auto op = tmpQueue.front(); tmpQueue.pop(); + // remain transfer read operation to do the broadcast fusion + if (isa(op)) { + auto srcOp = op->getOperand(0).getDefiningOp(); + assert(isa(srcOp)); + // just remain write operation, it's size will + // bigger than 1 if not write operation. Because the last operation + // always be write in each group + if (tmpQueue.size() <= 1) { + continue; + } + rewriter.replaceOp(op, srcOp); + continue; + } + // anchor of multidim reduciton rectify if (isa(op)) { auto accSourceOp = op->getOperand(1).getDefiningOp(); getFusionStrategy().getOpAnchorPos()[accSourceOp] = getOperationVectorType(accSourceOp)->getRank() - 1; } + // case: + // %1 = some op + // %2 = tensor.empty() + // %3 = vector.transfer_write %1, %2 + // -> move emtpy operation before %1 for better generate %1 + if (isa(op)) { + auto srcOp = op->getOperand(1).getDefiningOp(); + if (isa_and_nonnull(srcOp)) { + Operation *writeVectorOp = op->getOperands()[0].getDefiningOp(); + if (visitedOperation[srcOp] >= visitedOperation[writeVectorOp]) { + srcOp->moveBefore(writeVectorOp); + visitedOperation[srcOp] = visitedOperation[writeVectorOp]; + } + } + } + newQueue.push(op); } + getFusionStrategy().getOpGroups()[idx] = newQueue; } } @@ -3263,7 +3413,7 @@ void VectorOperationAnalyzer::analysisGroupOperationResults() { } else { // if source operation is transfer_read, we need to generate a // same transfer_read operation like source operation. 
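+          // e.g. when `%v = vector.transfer_read %src` is produced in another
+          // group, the consumer group re-reads %src with its own
+          // transfer_read instead of consuming %v across the group boundary.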
- if (mlir::isa(sourceOp)) { + if (isa(sourceOp)) { auto transferReadOp = cast(sourceOp); auto opInit = canonicalizeCurrentOperation(op, dstRet.value(), idx, &transferReadOp); @@ -3279,35 +3429,70 @@ void VectorOperationAnalyzer::analysisGroupOperationResults() { if (!opGroupIndexMap.contains(op)) { continue; } - // TODO: add more operation to this case, write to constant value need + // TODO: add more operation to this case, write a constant value need // to do this if (isa(op)) { if (idx == 0) continue; } + if (isa(op)) { + if (idx == 1) { + // accumulate value, just empty tensor is okay + auto resultTensor = + getOperationResultTensor(sourceOp, visitedOperation); + auto opInit = canonicalizeCurrentOperation(op, resultTensor, idx); + updateOpOperandResultInGroups(opGroupIndexMap[op], op, opInit); + continue; + } else { + // source operation is the value + llvm::llvm_unreachable_internal( + "Need to add reduce constant operation optimization."); + } + } auto constantOp = cast(sourceOp); IRRewriter rewriter(constantOp); + size_t groupSteps = + getFusionStrategy().getGroupMaxSteps()[opGroupIndexMap[op]]; + if (isa(constantOp.getValue())) { - if (!srcOpCanoniclizedMap.contains(sourceOp)) { - auto [tsr, writeOpresult] = - canonicalizeSourceOperation(sourceOp, visitedOperation); - srcOpCanoniclizedMap.insert({sourceOp, {tsr, writeOpresult}}); + VectorType newOperandType = getVectorzedType(op, groupSteps); + auto valueType = dyn_cast(constantOp.getValue()); + if (valueType.isSplat()) { + FailureOr res = createArithSplatConstantOp( + rewriter, constantOp->getLoc(), valueType, newOperandType); + if (failed(res)) { + llvm::llvm_unreachable_internal("Wrong to create constant op."); + } + op->setOperand(idx, res.value()); + } else { + // TODO: need to test not splat value + // newConstantOp = srcWriter.create( + // srcOp->getLoc(), srcConstantOp.getValue()); + // if (!srcOpCanoniclizedMap.contains(sourceOp)) { + // auto [tsr, writeOpresult] = + // canonicalizeSourceOperation(sourceOp, visitedOperation); + // srcOpCanoniclizedMap.insert({sourceOp, {tsr, writeOpresult}}); + // } + // auto opInit = canonicalizeCurrentOperation( + // op, srcOpCanoniclizedMap[sourceOp].second, idx); + // updateOpOperandResultInGroups(opGroupIndexMap[op], op, opInit); + llvm::llvm_unreachable_internal( + "Can't support not splat constant value."); } - auto opInit = canonicalizeCurrentOperation( - op, srcOpCanoniclizedMap[sourceOp].second, idx); - updateOpOperandResultInGroups(opGroupIndexMap[op], op, opInit); } } } - if (mlir::isa(op) && !movedOperationSet.contains(op)) { - auto parentBlock = op->getBlock(); - op->moveBefore(parentBlock, parentBlock->getOperations().begin()); - movedOperationSet.insert(op); - } + // if (mlir::isa(op) && !movedOperationSet.contains(op)) { + // auto parentBlock = op->getBlock(); + // std::stack opStack; + + // op->moveBefore(parentBlock, parentBlock->getOperations().begin()); + // movedOperationSet.insert(op); + // } }); analysisEmptyGroup(); - specialOperationAnchorRectify(); + specialOperationRectify(visitedOperation); #undef RESULT_RETURN_TYPE LDBG("Complete analysis group operation results\n"); } @@ -3320,14 +3505,18 @@ void VectorOperationAnalyzer::analysisGroupOperaion() { mlir::FailureOr ForLoopGenerator::generateVectorizedForLoop( const size_t groupId, IRRewriter &rewriter, VectorType vectorType) { auto &resultSet = getGroupOpResults(); - auto &initArgs = getGroupOpInitArgs()[groupId]; assert(!resultSet.empty() && "Expected non-empty value"); // prepare for loop iterargs - 
SmallVector operands; - DenseMap operandIdxMap; - for (auto [idx, x] : llvm::enumerate(initArgs)) { + SmallVector operands, nextLoopResults; + DenseMap operandIdxMap, resultIdxMap; + DenseMap originalOperandMap, operandOriginalMap, + forResultOrignalResultMap; + auto &initArgs = getGroupOpInitArgs()[groupId]; + for (Value x : initArgs) { operands.emplace_back(x); operandIdxMap[x] = operands.size() - 1; + originalOperandMap[x] = x; + operandOriginalMap[x] = x; } ValueRange forIterArgs(operands); ArrayRef shapes = vectorType.getShape(); @@ -3335,7 +3524,19 @@ mlir::FailureOr ForLoopGenerator::generateVectorizedForLoop( // generate for loop auto forOp = constructNestedForOp( 0, groupId, rewriter, rewriter.getUnknownLoc(), forIterArgs, vectorType, - shapes, inductionVars, operandIdxMap); + shapes, inductionVars, operandIdxMap, originalOperandMap, + operandOriginalMap, nextLoopResults, resultIdxMap, + forResultOrignalResultMap); + auto replaceIfFn = [&](OpOperand &use) { + return use.getOwner()->getBlock() == forOp->getBlock(); + }; + for (auto x : nextLoopResults) { + auto originalResult = forResultOrignalResultMap[x]; + rewriter.replaceOpUsesWithIf(originalResult.getDefiningOp(), + forOp->getResults()[resultIdxMap[x]], + replaceIfFn); + } + return forOp; } diff --git a/lib/gc/Transforms/LowerTileVectorPass.cpp b/lib/gc/Transforms/LowerTileVectorPass.cpp index 749facc6b..3622d6dc0 100644 --- a/lib/gc/Transforms/LowerTileVectorPass.cpp +++ b/lib/gc/Transforms/LowerTileVectorPass.cpp @@ -6,6 +6,7 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// +#include "gc/Dialect/Linalgx/LinalgxOps.h" #include "gc/Transforms/Passes.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" @@ -42,6 +43,14 @@ namespace { #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") +#define IMPLEMENTED_MATMUL \ + linalgx::BatchReduceMatmulVnniOp, linalgx::MultiBatchMatmulOp, \ + linalg::BatchReduceMatmulOp, linalgx::Mm2DVnniOp, linalgx::Mm4DVnniOp, \ + linalg::MatmulOp, linalg::BatchMatmulOp, \ + linalg::BatchMatmulTransposeAOp, linalg::BatchMatmulTransposeBOp, \ + linalg::MatmulTransposeAOp, linalg::MatmulTransposeBOp, \ + linalg::QuantizedBatchMatmulOp, linalg::QuantizedMatmulOp + bool is_innermost_ir(Operation *op) { bool inner_most = true; op->walk([&inner_most](Operation *p) { @@ -54,6 +63,24 @@ bool is_innermost_ir(Operation *op) { return inner_most; } +static bool isMatchedOperationUsage(Operation *op) { + if (isa(op)) { + return true; + } + // operation produce for matmul can't lower + if (!isa(op)) { + return false; + } + + for (auto x : op->getUsers()) { + if (isa(x)) { + return true; + } + } + + return false; +} + /// Need to check if the reassociation are static/constant. LogicalResult lowerExpandOpPrecondition(tensor::ExpandShapeOp expandOp) { // @@ -425,6 +452,13 @@ struct OperationConvertTileVectorPass : public RewritePattern { if (!targetOp || !is_innermost_ir(op)) return rewriter.notifyMatchFailure(op, "Not expected operations."); + // linalg.fill + linalgx.batch_mutmul should not be lower to vector + // because these two operation is needed by brgemm optimization. 
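// For example (a sketch based on the test IR in this series; the exact operands
// are illustrative), a pair such as
//   %fill = linalg.fill ins(%cst : bf16) outs(%acc : tensor<32x32xf32>) -> tensor<32x32xf32>
//   %gemm = linalgx.batch_reduce_matmul_vnni ins(%a, %b : ...) outs(%fill : ...) -> tensor<32x32xf32>
// is matched by the brgemm path later in the pipeline, so both operations must
// stay at the linalg level here instead of being rewritten into vector form.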
+ if (isMatchedOperationUsage(op)) { + return rewriter.notifyMatchFailure( + op, "linalg.fill + linalgx.batch_matmul can't do lowering."); + } + return linalg::vectorize(rewriter, op, /*inputVectorSizes=*/{}, /*scalableVecDims=*/{}, vectorizeNDExtract, flatten1DDepthwiseConv); @@ -511,6 +545,7 @@ struct LowerTileVectorPass vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + // clean up useless IR auto curOp = getOperation(); IRRewriter reWriter(curOp); DominanceInfo domInfo(curOp); diff --git a/lib/gc/Transforms/Pipeline.cpp b/lib/gc/Transforms/Pipeline.cpp index 29a143835..9a2805668 100644 --- a/lib/gc/Transforms/Pipeline.cpp +++ b/lib/gc/Transforms/Pipeline.cpp @@ -10,6 +10,7 @@ #include "mlir/Dialect/Arith/Transforms/Passes.h" #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" #include "mlir/Dialect/Bufferization/Transforms/Passes.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/Transforms/Passes.h" #include "mlir/Dialect/Linalg/Passes.h" @@ -53,6 +54,8 @@ void populateTensorPasses(mlir::OpPassManager &pm) { // scf + arith + math + vector + tensor + linalg.brgemm void populateVectorPasses(mlir::OpPassManager &pm) { + + pm.addPass(createLowerToTileVector()); // Do promotion for math / arith ops pm.addNestedPass(math::createMathLegalizeToF32()); // sourceTypeStrs can be extended @@ -65,7 +68,8 @@ void populateVectorPasses(mlir::OpPassManager &pm) { // Bf16 cast elimilation pass pm.addNestedPass(mlir::createCanonicalizerPass()); // oneDNN graph spec - pm.addNestedPass(arith::createArithExpandOpsPass()); + // pm.addNestedPass(arith::createArithExpandOpsPass()); + pm.addNestedPass(createCPUPhysicalRegisterPass()); // todo: lower to physical vector pass, device dependent pass } diff --git a/test/gc/transforms/cpu-vetor-distribution.mlir b/test/gc/transforms/cpu-vetor-distribution.mlir index b09ff01b3..6eb89574a 100644 --- a/test/gc/transforms/cpu-vetor-distribution.mlir +++ b/test/gc/transforms/cpu-vetor-distribution.mlir @@ -1,4 +1,4 @@ -// RUN: gc-opt %s --split-input-file --lower-to-tile-vector --CPU-physical-register-pass | FileCheck %s +// RUN: gc-opt %s --split-input-file --lower-to-tile-vector --CPU-physical-register-pass --mlir-print-ir-after-all | FileCheck %s // CHECK-LABEL: func @add_tensor_test0 func.func @add_tensor_test0(%arg0: tensor<11008x4096xf32>, %arg1: tensor<11008x4096xf32>) -> tensor<11008x4096xf32> { @@ -22,7 +22,21 @@ func.func @add_tensor_test0(%arg0: tensor<11008x4096xf32>, %arg1: tensor<11008x4 return %2 : tensor<11008x4096xf32> } -// CHECK-LABEL: func @fc_relu +func.func @reduce_keepdim0(%arg0: tensor<16x32x64xf32>) -> tensor<16x1x64xf32> { + %0 = tensor.empty() : tensor<16x64xf32> + %reduce = linalg.reduce + ins(%arg0:tensor<16x32x64xf32>) + outs(%0:tensor<16x64xf32>) + dimensions = [1] + (%in: f32, %out: f32) { + %1 = arith.addf %out, %in: f32 + linalg.yield %1: f32 + } + %2 = tensor.expand_shape %reduce [[0],[1, 2]] output_shape [16, 1, 64] : tensor<16x64xf32> into tensor<16x1x64xf32> + return %2 : tensor<16x1x64xf32> +} + +// // CHECK-LABEL: func @fc_relu func.func @fc_relu(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>, %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>) -> tensor<512x512xf32> { @@ -146,7 +160,7 @@ func.func @matmul_add(%arg0: tensor<8192x12288xf16>, %arg1: tensor<12288x16384xf #map2 = affine_map<(d0) -> (d0 * 4)> #map3 = affine_map<(d0) -> (d0 floordiv 
16)> #map4 = affine_map<(d0) -> (d0 floordiv 32)> - func.func @mlp(%arg0: tensor<128x512xbf16>, %arg1: tensor<32x8x16x32xbf16>, %arg2: tensor<256xbf16>) -> tensor<128x256xbf16> { +func.func @fuse_mlp(%arg0: tensor<128x512xbf16>, %arg1: tensor<32x8x16x32xbf16>, %arg2: tensor<256xbf16>) -> tensor<128x256xbf16> { %c32 = arith.constant 32 : index %c512 = arith.constant 512 : index %c128 = arith.constant 128 : index @@ -162,207 +176,68 @@ func.func @matmul_add(%arg0: tensor<8192x12288xf16>, %arg1: tensor<12288x16384xf %5 = affine.apply #map2(%arg4) %extracted_slice_0 = tensor.extract_slice %arg1[0, %5, 0, 0] [32, 4, 16, 32] [1, 1, 1, 1] : tensor<32x8x16x32xbf16> to tensor<32x4x16x32xbf16> %extracted_slice_1 = tensor.extract_slice %1[0, %4] [512, 128] [1, 1] : tensor<512x256xbf16> to tensor<512x128xbf16> - %6 = affine.apply #map(%arg3) - %7 = affine.apply #map1(%arg4) - %extracted_slice_2 = tensor.extract_slice %arg5[%6, %7] [64, 128] [1, 1] : tensor<128x256xbf16> to tensor<64x128xbf16> - %8 = affine.apply #map(%arg3) - %9 = affine.apply #map1(%arg4) - %10 = affine.apply #map(%arg3) - %11 = affine.apply #map1(%arg4) - %12 = affine.apply #map1(%arg4) - %13 = affine.apply #map(%arg3) - %14 = affine.apply #map1(%arg4) - %extracted_slice_3 = tensor.extract_slice %arg2[%12] [128] [1] : tensor<256xbf16> to tensor<128xbf16> - %extracted_slice_4 = tensor.extract_slice %0[%13, %14] [64, 128] [1, 1] : tensor<128x256xbf16> to tensor<64x128xbf16> - %extracted_slice_5 = tensor.extract_slice %arg6[%10, %11] [64, 128] [1, 1] : tensor<128x256xbf16> to tensor<64x128xbf16> - %15 = affine.apply #map(%arg3) - %16 = affine.apply #map1(%arg4) - %17 = affine.apply #map(%arg3) - %18 = affine.apply #map1(%arg4) - %extracted_slice_6 = tensor.extract_slice %arg7[%17, %18] [64, 128] [1, 1] : tensor<128x256xbf16> to tensor<64x128xbf16> - %19:3 = scf.for %arg8 = %c0 to %c64 step %c64 iter_args(%arg9 = %extracted_slice_2, %arg10 = %extracted_slice_5, %arg11 = %extracted_slice_6) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { - %22:3 = scf.for %arg12 = %c0 to %c128 step %c128 iter_args(%arg13 = %arg9, %arg14 = %arg10, %arg15 = %arg11) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { - %23:3 = scf.for %arg16 = %c0 to %c512 step %c512 iter_args(%arg17 = %arg13, %arg18 = %arg14, %arg19 = %arg15) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { + %extracted_slice_2 = tensor.extract_slice %arg5[%3, %4] [64, 128] [1, 1] : tensor<128x256xbf16> to tensor<64x128xbf16> + %extracted_slice_3 = tensor.extract_slice %arg2[%4] [128] [1] : tensor<256xbf16> to tensor<128xbf16> + %extracted_slice_4 = tensor.extract_slice %0[%3, %4] [64, 128] [1, 1] : tensor<128x256xbf16> to tensor<64x128xbf16> + %extracted_slice_5 = tensor.extract_slice %arg6[%3, %4] [64, 128] [1, 1] : tensor<128x256xbf16> to tensor<64x128xbf16> + %extracted_slice_6 = tensor.extract_slice %arg7[%3, %4] [64, 128] [1, 1] : tensor<128x256xbf16> to tensor<64x128xbf16> + %6:3 = scf.for %arg8 = %c0 to %c64 step %c64 iter_args(%arg9 = %extracted_slice_2, %arg10 = %extracted_slice_5, %arg11 = %extracted_slice_6) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { + %7:3 = scf.for %arg12 = %c0 to %c128 step %c128 iter_args(%arg13 = %arg9, %arg14 = %arg10, %arg15 = %arg11) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { + %8:3 = scf.for %arg16 = %c0 to %c512 step %c512 iter_args(%arg17 = %arg13, %arg18 = %arg14, %arg19 = %arg15) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, 
tensor<64x128xbf16>) { %extracted_slice_7 = tensor.extract_slice %extracted_slice[%arg8, %arg16] [64, 512] [1, 1] : tensor<64x512xbf16> to tensor<64x512xbf16> - %24 = affine.apply #map3(%arg16) - %25 = affine.apply #map4(%arg12) - %extracted_slice_8 = tensor.extract_slice %extracted_slice_0[%24, %25, 0, 0] [32, 4, 16, 32] [1, 1, 1, 1] : tensor<32x4x16x32xbf16> to tensor<32x4x16x32xbf16> + %9 = affine.apply #map3(%arg16) + %10 = affine.apply #map4(%arg12) + %extracted_slice_8 = tensor.extract_slice %extracted_slice_0[%9, %10, 0, 0] [32, 4, 16, 32] [1, 1, 1, 1] : tensor<32x4x16x32xbf16> to tensor<32x4x16x32xbf16> %extracted_slice_9 = tensor.extract_slice %extracted_slice_1[%arg16, %arg12] [512, 128] [1, 1] : tensor<512x128xbf16> to tensor<512x128xbf16> %extracted_slice_10 = tensor.extract_slice %arg17[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> to tensor<64x128xbf16> %extracted_slice_11 = tensor.extract_slice %extracted_slice_3[%arg12] [128] [1] : tensor<128xbf16> to tensor<128xbf16> %extracted_slice_12 = tensor.extract_slice %extracted_slice_4[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> to tensor<64x128xbf16> %extracted_slice_13 = tensor.extract_slice %arg18[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> to tensor<64x128xbf16> %extracted_slice_14 = tensor.extract_slice %arg19[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> to tensor<64x128xbf16> - %26:3 = scf.for %arg20 = %c0 to %c64 step %c32 iter_args(%arg21 = %extracted_slice_10, %arg22 = %extracted_slice_13, %arg23 = %extracted_slice_14) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { - %27:3 = scf.for %arg24 = %c0 to %c128 step %c32 iter_args(%arg25 = %arg21, %arg26 = %arg22, %arg27 = %arg23) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { - %28:3 = scf.for %arg28 = %c0 to %c512 step %c512 iter_args(%arg29 = %arg25, %arg30 = %arg26, %arg31 = %arg27) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { + %11:3 = scf.for %arg20 = %c0 to %c64 step %c32 iter_args(%arg21 = %extracted_slice_10, %arg22 = %extracted_slice_13, %arg23 = %extracted_slice_14) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { + %12:3 = scf.for %arg24 = %c0 to %c128 step %c32 iter_args(%arg25 = %arg21, %arg26 = %arg22, %arg27 = %arg23) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { + %13:3 = scf.for %arg28 = %c0 to %c512 step %c512 iter_args(%arg29 = %arg25, %arg30 = %arg26, %arg31 = %arg27) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { %extracted_slice_17 = tensor.extract_slice %extracted_slice_7[%arg20, %arg28] [32, 512] [1, 1] : tensor<64x512xbf16> to tensor<32x512xbf16> - %29 = affine.apply #map3(%arg28) - %30 = affine.apply #map4(%arg24) - %extracted_slice_18 = tensor.extract_slice %extracted_slice_8[%29, %30, 0, 0] [32, 1, 16, 32] [1, 1, 1, 1] : tensor<32x4x16x32xbf16> to tensor<32x1x16x32xbf16> + %14 = affine.apply #map3(%arg28) + %15 = affine.apply #map4(%arg24) + %extracted_slice_18 = tensor.extract_slice %extracted_slice_8[%14, %15, 0, 0] [32, 1, 16, 32] [1, 1, 1, 1] : tensor<32x4x16x32xbf16> to tensor<32x1x16x32xbf16> %extracted_slice_19 = tensor.extract_slice %extracted_slice_9[%arg28, %arg24] [512, 32] [1, 1] : tensor<512x128xbf16> to tensor<512x32xbf16> %unpack = tensor.unpack %extracted_slice_18 inner_dims_pos = [0, 1] inner_tiles = [16, 32] into %extracted_slice_19 : tensor<32x1x16x32xbf16> -> tensor<512x32xbf16> %extracted_slice_20 = tensor.extract_slice %arg29[%arg20, %arg24] [32, 32] [1, 1] : 
tensor<64x128xbf16> to tensor<32x32xbf16> - %31 = linalg.fill ins(%cst : bf16) outs(%extracted_slice_20 : tensor<32x32xbf16>) -> tensor<32x32xbf16> + %16 = linalg.fill ins(%cst : bf16) outs(%extracted_slice_20 : tensor<32x32xbf16>) -> tensor<32x32xbf16> %expanded = tensor.expand_shape %extracted_slice_17 [[0, 1], [2]] output_shape [1, 32, 512] : tensor<32x512xbf16> into tensor<1x32x512xbf16> %expanded_21 = tensor.expand_shape %unpack [[0, 1], [2]] output_shape [1, 32, 512] : tensor<512x32xbf16> into tensor<1x512x32xbf16> - %32 = linalg.batch_reduce_matmul ins(%expanded, %expanded_21 : tensor<1x32x512xbf16>, tensor<1x512x32xbf16>) outs(%31 : tensor<32x32xbf16>) -> tensor<32x32xbf16> + %17 = linalg.batch_reduce_matmul ins(%expanded, %expanded_21 : tensor<1x32x512xbf16>, tensor<1x512x32xbf16>) outs(%16 : tensor<32x32xbf16>) -> tensor<32x32xbf16> %extracted_slice_22 = tensor.extract_slice %extracted_slice_11[%arg24] [32] [1] : tensor<128xbf16> to tensor<32xbf16> %extracted_slice_23 = tensor.extract_slice %extracted_slice_12[%arg20, %arg24] [32, 32] [1, 1] : tensor<64x128xbf16> to tensor<32x32xbf16> %broadcasted = linalg.broadcast ins(%extracted_slice_22 : tensor<32xbf16>) outs(%extracted_slice_23 : tensor<32x32xbf16>) dimensions = [0] %extracted_slice_24 = tensor.extract_slice %arg30[%arg20, %arg24] [32, 32] [1, 1] : tensor<64x128xbf16> to tensor<32x32xbf16> - %33 = linalg.add ins(%32, %broadcasted : tensor<32x32xbf16>, tensor<32x32xbf16>) outs(%extracted_slice_24 : tensor<32x32xbf16>) -> tensor<32x32xbf16> - %inserted_slice_25 = tensor.insert_slice %32 into %arg29[%arg20, %arg24] [32, 32] [1, 1] : tensor<32x32xbf16> into tensor<64x128xbf16> + %18 = linalg.add ins(%17, %broadcasted : tensor<32x32xbf16>, tensor<32x32xbf16>) outs(%extracted_slice_24 : tensor<32x32xbf16>) -> tensor<32x32xbf16> + %inserted_slice_25 = tensor.insert_slice %17 into %arg29[%arg20, %arg24] [32, 32] [1, 1] : tensor<32x32xbf16> into tensor<64x128xbf16> %extracted_slice_26 = tensor.extract_slice %arg31[%arg20, %arg24] [32, 32] [1, 1] : tensor<64x128xbf16> to tensor<32x32xbf16> - %34 = linalg.exp ins(%33 : tensor<32x32xbf16>) outs(%extracted_slice_26 : tensor<32x32xbf16>) -> tensor<32x32xbf16> - %inserted_slice_27 = tensor.insert_slice %33 into %arg30[%arg20, %arg24] [32, 32] [1, 1] : tensor<32x32xbf16> into tensor<64x128xbf16> - %inserted_slice_28 = tensor.insert_slice %34 into %arg31[%arg20, %arg24] [32, 32] [1, 1] : tensor<32x32xbf16> into tensor<64x128xbf16> + %19 = linalg.exp ins(%18 : tensor<32x32xbf16>) outs(%extracted_slice_26 : tensor<32x32xbf16>) -> tensor<32x32xbf16> + %inserted_slice_27 = tensor.insert_slice %18 into %arg30[%arg20, %arg24] [32, 32] [1, 1] : tensor<32x32xbf16> into tensor<64x128xbf16> + %inserted_slice_28 = tensor.insert_slice %19 into %arg31[%arg20, %arg24] [32, 32] [1, 1] : tensor<32x32xbf16> into tensor<64x128xbf16> scf.yield %inserted_slice_25, %inserted_slice_27, %inserted_slice_28 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> } - scf.yield %28#0, %28#1, %28#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> + scf.yield %13#0, %13#1, %13#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> } - scf.yield %27#0, %27#1, %27#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> + scf.yield %12#0, %12#1, %12#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> } - %inserted_slice = tensor.insert_slice %26#0 into %arg17[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<64x128xbf16> - %inserted_slice_15 = 
tensor.insert_slice %26#1 into %arg18[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<64x128xbf16> - %inserted_slice_16 = tensor.insert_slice %26#2 into %arg19[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<64x128xbf16> + %inserted_slice = tensor.insert_slice %11#0 into %arg17[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<64x128xbf16> + %inserted_slice_15 = tensor.insert_slice %11#1 into %arg18[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<64x128xbf16> + %inserted_slice_16 = tensor.insert_slice %11#2 into %arg19[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<64x128xbf16> scf.yield %inserted_slice, %inserted_slice_15, %inserted_slice_16 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> } - scf.yield %23#0, %23#1, %23#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> + scf.yield %8#0, %8#1, %8#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> } - scf.yield %22#0, %22#1, %22#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> + scf.yield %7#0, %7#1, %7#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> } - %20 = affine.apply #map(%arg3) - %21 = affine.apply #map1(%arg4) scf.forall.in_parallel { - tensor.parallel_insert_slice %19#2 into %arg7[%20, %21] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<128x256xbf16> - tensor.parallel_insert_slice %19#1 into %arg6[%15, %16] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<128x256xbf16> - tensor.parallel_insert_slice %19#0 into %arg5[%8, %9] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<128x256xbf16> + tensor.parallel_insert_slice %6#2 into %arg7[%3, %4] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<128x256xbf16> + tensor.parallel_insert_slice %6#1 into %arg6[%3, %4] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<128x256xbf16> + tensor.parallel_insert_slice %6#0 into %arg5[%3, %4] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<128x256xbf16> } } return %2#2 : tensor<128x256xbf16> } - - - - func.func @main_entry(%arg0: tensor<128x128x64x64xbf16>, %arg1: tensor<128x128x32x64x2xbf16>) -> tensor<128x128x64x64xbf16> attributes {llvm.emit_c_interface} { - %c2 = arith.constant 2 : index - %c128 = arith.constant 128 : index - %c1 = arith.constant 1 : index - %c16 = arith.constant 16 : index - %c64 = arith.constant 64 : index - %c32 = arith.constant 32 : index - %c4 = arith.constant 4 : index - %c0 = arith.constant 0 : index - %cst = arith.constant 0.000000e+00 : bf16 - %0 = tensor.empty() : tensor<128x128x64x64xbf16> - %1 = tensor.empty() : tensor<128x128x64x64xf32> - %2 = tensor.empty() : tensor<2x1x1x128x128x64x64xf32> - %3 = scf.forall (%arg2) in (2) shared_outs(%arg3 = %2) -> (tensor<2x1x1x128x128x64x64xf32>) { - %extracted_slice = tensor.extract_slice %arg3[%arg2, 0, 0, 0, 0, 0, 0] [1, 1, 1, 128, 128, 64, 64] [1, 1, 1, 1, 1, 1, 1] : tensor<2x1x1x128x128x64x64xf32> to tensor<128x128x64x64xf32> - %5 = scf.forall (%arg4) in (7) shared_outs(%arg5 = %extracted_slice) -> (tensor<128x128x64x64xf32>) { - %6 = affine.min affine_map<(d0) -> (d0 * -19 + 128, 19)>(%arg4) - %7 = affine.max affine_map<(d0) -> (0, d0)>(%6) - %8 = affine.apply affine_map<(d0) -> (d0 * 19)>(%arg4) - %extracted_slice_0 = tensor.extract_slice %arg5[%8, 0, 0, 0] [%7, 128, 64, 64] [1, 1, 1, 1] : tensor<128x128x64x64xf32> to tensor - %9 = scf.forall (%arg6) in (4) shared_outs(%arg7 = %extracted_slice_0) -> (tensor) { - %11 = affine.max affine_map<(d0) -> (0, d0)>(%6) - %12 = affine.apply 
affine_map<(d0) -> (d0 * 32)>(%arg6) - %extracted_slice_1 = tensor.extract_slice %arg7[0, %12, 0, 0] [%11, 32, 64, 64] [1, 1, 1, 1] : tensor to tensor - %13 = scf.for %arg8 = %c0 to %11 step %c4 iter_args(%arg9 = %extracted_slice_1) -> (tensor) { - %15 = affine.min affine_map<(d0)[s0] -> (-d0 + s0, 4)>(%arg8)[%11] - %extracted_slice_2 = tensor.extract_slice %arg9[%arg8, 0, 0, 0] [%15, 32, 64, 64] [1, 1, 1, 1] : tensor to tensor - %16 = scf.for %arg10 = %c0 to %c32 step %c4 iter_args(%arg11 = %extracted_slice_2) -> (tensor) { - %extracted_slice_3 = tensor.extract_slice %arg11[0, %arg10, 0, 0] [%15, 4, 64, 64] [1, 1, 1, 1] : tensor to tensor - %17 = scf.for %arg12 = %c0 to %c64 step %c16 iter_args(%arg13 = %extracted_slice_3) -> (tensor) { - %extracted_slice_5 = tensor.extract_slice %arg13[0, 0, 0, 0] [%15, 4, 64, 64] [1, 1, 1, 1] : tensor to tensor - %18 = scf.for %arg14 = %c0 to %15 step %c1 iter_args(%arg15 = %extracted_slice_5) -> (tensor) { - %extracted_slice_7 = tensor.extract_slice %arg15[%arg14, 0, 0, 0] [1, 4, 64, 64] [1, 1, 1, 1] : tensor to tensor<1x4x64x64xf32> - %19 = scf.for %arg16 = %c0 to %c4 step %c1 iter_args(%arg17 = %extracted_slice_7) -> (tensor<1x4x64x64xf32>) { - %20 = affine.apply affine_map<(d0)[s0, s1] -> (d0 * 19 + s0 + s1)>(%arg4)[%arg14, %arg8] - %21 = affine.apply affine_map<(d0)[s0] -> (d0 * 64 + s0)>(%arg2)[%arg12] - %extracted_slice_9 = tensor.extract_slice %arg0[%20, %21, 0, 0] [1, 16, 64, 64] [1, 1, 1, 1] : tensor<128x128x64x64xbf16> to tensor<16x64x64xbf16> - %22 = affine.apply affine_map<(d0)[s0, s1] -> (d0 * 32 + s0 + s1)>(%arg6)[%arg16, %arg10] - %23 = affine.apply affine_map<(d0)[s0] -> (d0 * 64 + s0)>(%arg2)[%arg12] - %extracted_slice_10 = tensor.extract_slice %arg1[%22, %23, 0, 0, 0] [1, 16, 32, 64, 2] [1, 1, 1, 1, 1] : tensor<128x128x32x64x2xbf16> to tensor<16x32x64x2xbf16> - %extracted_slice_11 = tensor.extract_slice %arg17[0, %arg16, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : tensor<1x4x64x64xf32> to tensor<64x64xf32> - %24 = arith.cmpi eq, %arg12, %c0 : index - %25 = scf.if %24 -> (tensor<64x64xf32>) { - %26 = linalg.fill ins(%cst : bf16) outs(%extracted_slice_11 : tensor<64x64xf32>) -> tensor<64x64xf32> - %27 = linalgx.batch_reduce_matmul_vnni ins(%extracted_slice_9, %extracted_slice_10 : tensor<16x64x64xbf16>, tensor<16x32x64x2xbf16>) outs(%26 : tensor<64x64xf32>) -> tensor<64x64xf32> - scf.yield %27 : tensor<64x64xf32> - } else { - %26 = linalgx.batch_reduce_matmul_vnni ins(%extracted_slice_9, %extracted_slice_10 : tensor<16x64x64xbf16>, tensor<16x32x64x2xbf16>) outs(%extracted_slice_11 : tensor<64x64xf32>) -> tensor<64x64xf32> - scf.yield %26 : tensor<64x64xf32> - } - %inserted_slice_12 = tensor.insert_slice %25 into %arg17[0, %arg16, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : tensor<64x64xf32> into tensor<1x4x64x64xf32> - scf.yield %inserted_slice_12 : tensor<1x4x64x64xf32> - } - %inserted_slice_8 = tensor.insert_slice %19 into %arg15[%arg14, 0, 0, 0] [1, 4, 64, 64] [1, 1, 1, 1] : tensor<1x4x64x64xf32> into tensor - scf.yield %inserted_slice_8 : tensor - } - %inserted_slice_6 = tensor.insert_slice %18 into %arg13[0, 0, 0, 0] [%15, 4, 64, 64] [1, 1, 1, 1] : tensor into tensor - scf.yield %inserted_slice_6 : tensor - } - %inserted_slice_4 = tensor.insert_slice %17 into %arg11[0, %arg10, 0, 0] [%15, 4, 64, 64] [1, 1, 1, 1] : tensor into tensor - scf.yield %inserted_slice_4 : tensor - } - %inserted_slice = tensor.insert_slice %16 into %arg9[%arg8, 0, 0, 0] [%15, 32, 64, 64] [1, 1, 1, 1] : tensor into tensor - scf.yield %inserted_slice : tensor - } - %14 = 
affine.apply affine_map<(d0) -> (d0 * 32)>(%arg6) - scf.forall.in_parallel { - tensor.parallel_insert_slice %13 into %arg7[0, %14, 0, 0] [%11, 32, 64, 64] [1, 1, 1, 1] : tensor into tensor - } - } - %10 = affine.apply affine_map<(d0) -> (d0 * 19)>(%arg4) - scf.forall.in_parallel { - tensor.parallel_insert_slice %9 into %arg5[%10, 0, 0, 0] [%7, 128, 64, 64] [1, 1, 1, 1] : tensor into tensor<128x128x64x64xf32> - } - } - scf.forall.in_parallel { - tensor.parallel_insert_slice %5 into %arg3[%arg2, 0, 0, 0, 0, 0, 0] [1, 1, 1, 128, 128, 64, 64] [1, 1, 1, 1, 1, 1, 1] : tensor<128x128x64x64xf32> into tensor<2x1x1x128x128x64x64xf32> - } - } - %4 = scf.forall (%arg2) in (128) shared_outs(%arg3 = %0) -> (tensor<128x128x64x64xbf16>) { - %extracted_slice = tensor.extract_slice %1[%arg2, 0, 0, 0] [1, 128, 64, 64] [1, 1, 1, 1] : tensor<128x128x64x64xf32> to tensor<1x128x64x64xf32> - %extracted_slice_0 = tensor.extract_slice %arg3[%arg2, 0, 0, 0] [1, 128, 64, 64] [1, 1, 1, 1] : tensor<128x128x64x64xbf16> to tensor<1x128x64x64xbf16> - %5:2 = scf.for %arg4 = %c0 to %c128 step %c1 iter_args(%arg5 = %extracted_slice, %arg6 = %extracted_slice_0) -> (tensor<1x128x64x64xf32>, tensor<1x128x64x64xbf16>) { - %extracted_slice_1 = tensor.extract_slice %arg5[0, %arg4, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : tensor<1x128x64x64xf32> to tensor<1x1x64x64xf32> - %extracted_slice_2 = tensor.extract_slice %arg6[0, %arg4, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : tensor<1x128x64x64xbf16> to tensor<1x1x64x64xbf16> - %6:2 = scf.for %arg7 = %c0 to %c64 step %c1 iter_args(%arg8 = %extracted_slice_1, %arg9 = %extracted_slice_2) -> (tensor<1x1x64x64xf32>, tensor<1x1x64x64xbf16>) { - %extracted_slice_4 = tensor.extract_slice %arg8[0, 0, %arg7, 0] [1, 1, 1, 64] [1, 1, 1, 1] : tensor<1x1x64x64xf32> to tensor<1x1x1x64xf32> - %extracted_slice_5 = tensor.extract_slice %arg9[0, 0, %arg7, 0] [1, 1, 1, 64] [1, 1, 1, 1] : tensor<1x1x64x64xbf16> to tensor<1x1x1x64xbf16> - %7:2 = scf.for %arg10 = %c0 to %c64 step %c32 iter_args(%arg11 = %extracted_slice_4, %arg12 = %extracted_slice_5) -> (tensor<1x1x1x64xf32>, tensor<1x1x1x64xbf16>) { - %extracted_slice_8 = tensor.extract_slice %arg11[0, 0, 0, %arg10] [1, 1, 1, 32] [1, 1, 1, 1] : tensor<1x1x1x64xf32> to tensor<1x1x1x32xf32> - %8 = linalg.fill ins(%cst : bf16) outs(%extracted_slice_8 : tensor<1x1x1x32xf32>) -> tensor<1x1x1x32xf32> - %9 = scf.for %arg13 = %c0 to %c2 step %c1 iter_args(%arg14 = %8) -> (tensor<1x1x1x32xf32>) { - %extracted_slice_12 = tensor.extract_slice %3[%arg13, 0, 0, %arg2, %arg4, %arg7, %arg10] [1, 1, 1, 1, 1, 1, 32] [1, 1, 1, 1, 1, 1, 1] : tensor<2x1x1x128x128x64x64xf32> to tensor<1x1x1x1x1x1x32xf32> - %reduced = linalg.reduce ins(%extracted_slice_12 : tensor<1x1x1x1x1x1x32xf32>) outs(%arg14 : tensor<1x1x1x32xf32>) dimensions = [0, 1, 2] - (%in: f32, %init: f32) { - %11 = arith.addf %in, %init : f32 - linalg.yield %11 : f32 - } - scf.yield %reduced : tensor<1x1x1x32xf32> - } - %extracted_slice_9 = tensor.extract_slice %arg12[0, 0, 0, %arg10] [1, 1, 1, 32] [1, 1, 1, 1] : tensor<1x1x1x64xbf16> to tensor<1x1x1x32xbf16> - %10 = linalg.copy ins(%9 : tensor<1x1x1x32xf32>) outs(%extracted_slice_9 : tensor<1x1x1x32xbf16>) -> tensor<1x1x1x32xbf16> - %inserted_slice_10 = tensor.insert_slice %9 into %arg11[0, 0, 0, %arg10] [1, 1, 1, 32] [1, 1, 1, 1] : tensor<1x1x1x32xf32> into tensor<1x1x1x64xf32> - %inserted_slice_11 = tensor.insert_slice %10 into %arg12[0, 0, 0, %arg10] [1, 1, 1, 32] [1, 1, 1, 1] : tensor<1x1x1x32xbf16> into tensor<1x1x1x64xbf16> - scf.yield %inserted_slice_10, 
%inserted_slice_11 : tensor<1x1x1x64xf32>, tensor<1x1x1x64xbf16> - } - %inserted_slice_6 = tensor.insert_slice %7#0 into %arg8[0, 0, %arg7, 0] [1, 1, 1, 64] [1, 1, 1, 1] : tensor<1x1x1x64xf32> into tensor<1x1x64x64xf32> - %inserted_slice_7 = tensor.insert_slice %7#1 into %arg9[0, 0, %arg7, 0] [1, 1, 1, 64] [1, 1, 1, 1] : tensor<1x1x1x64xbf16> into tensor<1x1x64x64xbf16> - scf.yield %inserted_slice_6, %inserted_slice_7 : tensor<1x1x64x64xf32>, tensor<1x1x64x64xbf16> - } - %inserted_slice = tensor.insert_slice %6#0 into %arg5[0, %arg4, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : tensor<1x1x64x64xf32> into tensor<1x128x64x64xf32> - %inserted_slice_3 = tensor.insert_slice %6#1 into %arg6[0, %arg4, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : tensor<1x1x64x64xbf16> into tensor<1x128x64x64xbf16> - scf.yield %inserted_slice, %inserted_slice_3 : tensor<1x128x64x64xf32>, tensor<1x128x64x64xbf16> - } - scf.forall.in_parallel { - tensor.parallel_insert_slice %5#1 into %arg3[%arg2, 0, 0, 0] [1, 128, 64, 64] [1, 1, 1, 1] : tensor<1x128x64x64xbf16> into tensor<128x128x64x64xbf16> - } - } - return %4 : tensor<128x128x64x64xbf16> - } - From 39524f37da84c594f48c8145d63d3ba6aa7fe218 Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Thu, 1 Aug 2024 17:07:02 +0800 Subject: [PATCH 21/66] fix wrong permutation map due to community pass greedy fold bug --- lib/gc/Transforms/LowerTileVectorPass.cpp | 24 +++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/lib/gc/Transforms/LowerTileVectorPass.cpp b/lib/gc/Transforms/LowerTileVectorPass.cpp index 3622d6dc0..0049ed82c 100644 --- a/lib/gc/Transforms/LowerTileVectorPass.cpp +++ b/lib/gc/Transforms/LowerTileVectorPass.cpp @@ -51,6 +51,10 @@ namespace { linalg::MatmulTransposeAOp, linalg::MatmulTransposeBOp, \ linalg::QuantizedBatchMatmulOp, linalg::QuantizedMatmulOp +#define SUPPORT_TENSOR_OP \ + tensor::ExpandShapeOp, tensor::CollapseShapeOp, tensor::BitcastOp, \ + tensor::ConcatOp + bool is_innermost_ir(Operation *op) { bool inner_most = true; op->walk([&inner_most](Operation *p) { @@ -429,10 +433,7 @@ LogicalResult convert2TargetOperation(RewriterBase &rewriter, Operation *op) { } bool is_required_tensorOp(Operation *operation) { - return llvm::isa(operation) || - llvm::isa(operation) || - llvm::isa(operation) || - llvm::isa(operation); + return isa(operation); } template @@ -526,12 +527,16 @@ struct LowerTileVectorPass // auto *ctx = &getContext(); RewritePatternSet patterns(ctx); + auto funcOp = getOperation(); tensor::ControlFoldFn defaultControlFn = [](OpOperand *fusedOperand) { Operation *producer = fusedOperand->get().getDefiningOp(); return producer && producer->hasOneUse(); }; + // some operation convert as constant, this pattern can help us to improve + // the performance tensor::populateRewriteAsConstantPatterns(patterns, defaultControlFn); + // remove unnessary operation tensor::populateReassociativeReshapeFoldingPatterns(patterns); tensor::populateFoldTensorSubsetOpPatterns(patterns); tensor::populateFoldTensorEmptyPatterns(patterns, true); @@ -539,17 +544,16 @@ struct LowerTileVectorPass populateLowerToTileVectorPatterns(patterns); linalg::populatePadOpVectorizationPatterns(patterns); + // ensure read and write on last dimension vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); + // remove unnessary broadcast operation vector::populateSinkVectorBroadcastPatterns(patterns); vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx); vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx); - 
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); - // clean up useless IR - auto curOp = getOperation(); - IRRewriter reWriter(curOp); - DominanceInfo domInfo(curOp); - eliminateCommonSubExpressions(reWriter, domInfo, curOp); + GreedyRewriteConfig config; + config.strictMode = GreedyRewriteStrictness::ExistingOps; + (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns), config); } }; } // namespace From ca28f7cb50b495408e33257d483880c1ba731b30 Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Sat, 3 Aug 2024 12:39:15 +0800 Subject: [PATCH 22/66] fix reduce bugs --- include/gc/Transforms/Passes.td | 2 +- include/gc/Transforms/TilingVector.h | 112 +++- lib/gc/Transforms/CPUPhysicalResigterPass.cpp | 581 +++++++++++++----- lib/gc/Transforms/LowerTileVectorPass.cpp | 64 +- 4 files changed, 566 insertions(+), 193 deletions(-) diff --git a/include/gc/Transforms/Passes.td b/include/gc/Transforms/Passes.td index e2a64e7c0..66dbf4772 100644 --- a/include/gc/Transforms/Passes.td +++ b/include/gc/Transforms/Passes.td @@ -58,7 +58,7 @@ def LinalgToXeGPU : Pass<"linalg-to-xegpu", "func::FuncOp"> { ]; } -def LowerToTileVector : Pass<"lower-to-tile-vector"> { +def LowerToTileVector : Pass<"lower-to-tile-vector", "func::FuncOp"> { let summary = "Lower tensor to tile vector."; let description = [{ Lower tensor to tile vector form. diff --git a/include/gc/Transforms/TilingVector.h b/include/gc/Transforms/TilingVector.h index b527a73dd..ca28afbe8 100644 --- a/include/gc/Transforms/TilingVector.h +++ b/include/gc/Transforms/TilingVector.h @@ -214,13 +214,13 @@ class MultiReductionCanonicalizer bool getIsEmptyReduction() { return isEmptyReduction; } void initReductionAxis(); void initParallelAxis(); - llvm::SmallVector &getReductionAxis() { return reductionAxis; }; - llvm::SmallVector &getParallelAxis() { return parallelAxis; }; + SmallVector &getReductionAxis() { return reductionAxis; }; + SmallVector &getParallelAxis() { return parallelAxis; }; std::queue &getPrevOps() { return prevOps; } std::queue &getPostOps() { return postOps; } std::queue &getAccRelatedOps() { return accRelatedOps; } std::queue &getSourceRelatedOps() { return sourceRelatedOps; } - llvm::SetVector &getOriginalOpResults() { return originalOpResults; } + SetVector &getOriginalOpResults() { return originalOpResults; } VectorType getSourceType() { return sourceType; }; VectorType getAccType() { return accType; }; llvm::SmallDenseMap &getResultIdxMap() { return resultIdxMap; } @@ -266,10 +266,14 @@ class TransposeCanonicalizer static bool classof(SpecialOperationCanonicalizer *canonicalizer) { return canonicalizer->getKind() == SpecialOperationKind::OP_Transpose; } + enum TRANSPOSE_KERNEL { + KERNEL_16X16 = 16, + }; size_t getFirstTpIdx() { return firstTpIdx; } size_t getSecondTpIdx() { return secondTpIdx; } bool isTwoDTranspose(); + bool isTransposeOnAllOneDim(); }; class ShapeCastCanonicalizer @@ -367,7 +371,7 @@ class CanonicalizerCommonUsedData : public TypeHelper { return groupOpResults; } - llvm::SmallVector, 8> &getGroupOpInitArgs() { + SmallVector, 8> &getGroupOpInitArgs() { return groupOpInitArgs; } @@ -402,8 +406,11 @@ class CanonicalizerCommonUsedData : public TypeHelper { size_t anchorPos, ReturnTypeKind retKind, DenseMap &visitedOperation); - void updateOpOperandResultInGroups(size_t opGid, Operation *op, Value &init, + void updateOpOperandResultInGroups(size_t opGid, Operation *op, + const Value &init = Value(), const Value &result = Value()); + void removeOpInCurrentGroups(size_t grpIdx, 
Operation *op); + void updateOpGroupInfo(size_t grpIdx); Value canonicalizeCurrentOperation(Operation *op, const Value &transferReadOperand, @@ -413,6 +420,10 @@ class CanonicalizerCommonUsedData : public TypeHelper { Operation * createTransferReadOpBefore(Operation *op, const Value &operand, vector::TransferReadOp *srcReadOp = nullptr); + /// get next operation in current operation group + template + Operation *getNextTargetOperationInCurrentGroup(Operation *curOp, + const size_t grpIdx); }; class ForLoopGenerator : virtual public CanonicalizerCommonUsedData { @@ -427,6 +438,13 @@ class ForLoopGenerator : virtual public CanonicalizerCommonUsedData { void setGeneratorFunc(func::FuncOp &func) { this->func = func; } void clearCurrentOperationGroup(size_t grpIdx); void generateGroupOpVectorizedIR(const int idx); + + /// mark which operation need to set correct for loop var idx + /// due to sometimes we need to chage for loop order like reduce operation. + void getCurrentGroupIndiceLoopMap( + DenseMap> &indiceLoopMap, + const size_t groupId, Operation *op, + const DenseMap &setIdxMap = DenseMap({})); void rewriteOperationAsVectorize(OpBuilder &rewriter, size_t groupId, const std::queue *queue = nullptr); @@ -438,29 +456,41 @@ class ForLoopGenerator : virtual public CanonicalizerCommonUsedData { generateVectorizedForLoop(const size_t groupId, IRRewriter &rewriter, const VectorType vectorType); - scf::ForOp - constructNestedForOp(const size_t forDimIdx, const size_t groupIdx, - OpBuilder &b, const Location &loc, - const ValueRange &iterArgs, VectorType type, - const llvm::ArrayRef &dims, - llvm::SmallVector &inductionVars, - llvm::DenseMap &operandIdxMap, - DenseMap &originalOperandMap, - DenseMap &operandOriginalMap, - llvm::SmallVector &nextAnchorResults, - DenseMap &nextAnchorResultsIdxMap, - DenseMap &forResultOrignalResultMap); + scf::ForOp constructNestedForOp( + const size_t forDimIdx, const size_t groupIdx, OpBuilder &b, + const Location &loc, const ValueRange &iterArgs, + const llvm::ArrayRef &dims, + llvm::SmallVector &inductionVars, + llvm::DenseMap &operandIdxMap, + DenseMap &originalOperandMap, + DenseMap &operandOriginalMap, + llvm::SmallVector &nextAnchorResults, + DenseMap &nextAnchorResultsIdxMap, + DenseMap &forResultOrignalResultMap, + DenseMap> &indiceLoopMap); + void moveOperationsToCurrentForBody( const size_t groupIdx, const OpBuilder &b, ArrayRef inductionVars, const llvm::DenseMap &operandIdxMap, const ValueRange &loopState, DenseMap &originalOperandLoopArgsMap, - std::queue &queue); + std::queue &queue, + DenseMap> &indiceLoopMap); + void setOperationCorrectOperand( + Operation *op, const ValueRange &iterArgs, + const DenseMap &operandIdxMap, + DenseMap &originalOperandLoopArgsMap, + ArrayRef inductionVars, + const DenseMap &opPermuationMap, + DenseMap> &indiceLoopMap); void getResultInCurrentOps(const size_t anchorIdx, const size_t groupId, const std::queue ops, SmallVector &results, + DenseMap &nextAnchorResultsIdxMap, DenseMap &forResultOrignalResultMap); + + /// todo: need to add a struct to remove so many parameters void getInitArgsToNextAnchor(const size_t anchorIdx, const size_t groupId, const std::queue &nextOperations, @@ -482,6 +512,7 @@ class ForLoopGenerator : virtual public CanonicalizerCommonUsedData { const std::queue &movedOperaiton, DenseMap &forResultOrignalResultMap); + /// todo: need to add a struct to remove so many parameters void movePostOpToCurrentAnchor( OpBuilder &b, const int anchorIdx, const int groupIdx, const ValueRange &forResults, const 
Block *forBlock, @@ -492,19 +523,20 @@ class ForLoopGenerator : virtual public CanonicalizerCommonUsedData { DenseMap &originalOperandLoopArgsMap, DenseMap &loopArgsOriginalOperandMap, const llvm::SmallVector &nextAnchorResults, - DenseMap &forResultOrignalResultMap); + DenseMap &forResultOrignalResultMap, + DenseMap> &indiceLoopMap); - void - movePreOpToCurrentAnchor(const size_t anchorIdx, const size_t groupIdx, - OpBuilder &b, ArrayRef inductionVars, - const ValueRange &loopState, - llvm::DenseMap ¤tLoopStateIdxMap, - llvm::DenseMap &nextLoopStateIdxMap, - llvm::SmallVector &nextAnchorArgs, - std::queue &candidateQueue, - std::queue &movedQueue, - DenseMap &originalOperandLoopArgsMap, - DenseMap &LoopArgsoriginalOperandMap); + void movePreOpToCurrentAnchor( + const size_t anchorIdx, const size_t groupIdx, OpBuilder &b, + ArrayRef inductionVars, const ValueRange &loopState, + llvm::DenseMap ¤tLoopStateIdxMap, + llvm::DenseMap &nextLoopStateIdxMap, + llvm::SmallVector &nextAnchorArgs, + std::queue &candidateQueue, + std::queue &movedQueue, + DenseMap &originalOperandLoopArgsMap, + DenseMap &LoopArgsoriginalOperandMap, + DenseMap> &indiceLoopMap); void replaceOperationsWithForLoopResult( IRRewriter &rewrite, const ValueRange &forResults, const Block *forBlock, @@ -523,7 +555,8 @@ class ForLoopGenerator : virtual public CanonicalizerCommonUsedData { llvm::DenseMap &nextAnchorResultsIdxMap, llvm::SmallVector &inductionVars, DenseMap &forResultOrignalResultMap, - DenseMap &originalResultForResultMap); + DenseMap &originalResultForResultMap, + DenseMap> &indiceLoopMap); scf::ForOp parallelAxisGenerateForLoop( OpBuilder &opBuilder, const int groupIdx, const size_t parallelIdx, @@ -534,7 +567,8 @@ class ForLoopGenerator : virtual public CanonicalizerCommonUsedData { llvm::SmallVector &inductionVars, DenseMap &originalOperandLoopArgsMap, DenseMap &loopArgsOriginalOperandMap, - DenseMap &forResultOrignalResultMap); + DenseMap &forResultOrignalResultMap, + DenseMap> &indiceLoopMap); vector::TransferReadOp cloneReductionTransferRead( Value &source, OpBuilder &b, IRMapping &readMap, @@ -561,6 +595,20 @@ class ForLoopGenerator : virtual public CanonicalizerCommonUsedData { OpBuilder &opBuilder, const size_t grpIdx, const size_t forDimIdx, const size_t steps, const Location &loc, SmallVector &inductionVars, const ValueRange &iterArgs); + /// rectify indice for transfer_write operation + /// e.g.: vector.transfer_write"(%16, %9, %c0, %c0), the first %c0 should use + /// original indice not create by us + void rectifyWriteOperationIndice(vector::TransferWriteOp *originalWriteOp, + SmallVectorImpl &writeVars); + /// rectify indice for transfer_read operation, like broadcast operation + /// fusion by transfer_read , but the transfer_read operation is in innermost + /// for loop body, we must set correct for loop var. 
e.g.: + /// vector.transfer_read"(%16, %9, %c0), the first %c0 should use correct for + /// innermost loop iter vars + void rectifyReadOperationIndice(vector::TransferReadOp *originalReadOp, + VectorType loopType, + ArrayRef inductionVars, + SmallVectorImpl &readVars); }; class VectorOperationAnalyzer : virtual public CanonicalizerCommonUsedData { diff --git a/lib/gc/Transforms/CPUPhysicalResigterPass.cpp b/lib/gc/Transforms/CPUPhysicalResigterPass.cpp index db14a1999..f905e69f4 100644 --- a/lib/gc/Transforms/CPUPhysicalResigterPass.cpp +++ b/lib/gc/Transforms/CPUPhysicalResigterPass.cpp @@ -25,9 +25,12 @@ namespace { // TODO: remove it in the future bool disableSpecialOp = false; +bool disableBroadcastOp = false; void printGroupOps(SmallVector, 8> &opGroups) { - for (auto grp : opGroups) { + for (auto [idx, grp] : llvm::enumerate(opGroups)) { + llvm::outs() << idx << " group id." + << "\n"; if (grp.empty()) { continue; } @@ -465,16 +468,15 @@ float bfloat2float(uint16_t bfloatBits) { } bool isReadWriteOnLastDim(Operation *op) { - if (mlir::isa(op) || - mlir::isa(op)) { + if (isa(op)) { auto permutationMap = - mlir::dyn_cast(op) - ? mlir::dyn_cast(op).getPermutationMap() - : mlir::dyn_cast(op).getPermutationMap(); + dyn_cast(op) + ? dyn_cast(op).getPermutationMap() + : dyn_cast(op).getPermutationMap(); auto rank = - mlir::dyn_cast(op) - ? mlir::dyn_cast(op->getOperand(0).getType()).getRank() - : mlir::dyn_cast(op->getOperand(1).getType()).getRank(); + dyn_cast(op) + ? dyn_cast(op->getOperand(0).getType()).getRank() + : dyn_cast(op->getOperand(1).getType()).getRank(); auto dimExpr = permutationMap.getResults(); bool find = false; for (auto &expr : dimExpr) { @@ -632,8 +634,8 @@ void setOpVectorizationPermutationMap(Operation *op, OpBuilder &rewriter, const AffineMap &permutationMap) { auto dimExpr = permutationMap.getResults(); - auto lastDim = mlir::dyn_cast(dimExpr.back()); - assert(mlir::isa(lastDim)); + auto lastDim = dyn_cast(dimExpr.back()); + assert(isa(lastDim)); SmallVector affineExprs; affineExprs.push_back(lastDim); @@ -992,7 +994,8 @@ void ForLoopGenerator::moveOperationsToCurrentForBody( const size_t groupIdx, const OpBuilder &b, ArrayRef inductionVars, const DenseMap &operandIdxMap, const ValueRange &loopState, DenseMap &originalOperandLoopArgsMap, - std::queue &opQueue) { + std::queue &opQueue, + DenseMap> &indiceLoopMap) { auto &opPermuationMap = getOpPermuationMap(); auto tmpQ(opQueue); while (!tmpQ.empty()) { @@ -1002,7 +1005,7 @@ void ForLoopGenerator::moveOperationsToCurrentForBody( // check operation type to set correct operand setOperationCorrectOperand(x, loopState, operandIdxMap, originalOperandLoopArgsMap, inductionVars, - opPermuationMap); + opPermuationMap, indiceLoopMap); } } @@ -1026,6 +1029,7 @@ bool hasOtherOperations(const std::queue &opQ, void ForLoopGenerator::getResultInCurrentOps( const size_t anchorIdx, const size_t groupId, const std::queue ops, SmallVector &results, + DenseMap &nextAnchorResultsIdxMap, DenseMap &forResultOrignalResultMap) { auto tmpQ(ops); llvm::MapVector> &groupResults = @@ -1038,6 +1042,7 @@ void ForLoopGenerator::getResultInCurrentOps( std::pair retType = groupResults[curResult]; if (needReturnResult(retType, anchorIdx)) { results.emplace_back(curResult); + nextAnchorResultsIdxMap[curResult] = results.size() - 1; forResultOrignalResultMap[curResult] = curResult; } } @@ -1153,7 +1158,8 @@ void ForLoopGenerator::movePreOpToCurrentAnchor( std::queue &candidateQueue, std::queue &movedQueue, DenseMap &originalOperandLoopArgsMap, - 
DenseMap &LoopArgsoriginalOperandMap) { + DenseMap &LoopArgsoriginalOperandMap, + DenseMap> &indiceLoopMap) { // 1. get operations in current anchor position std::queue movingOperation; @@ -1163,9 +1169,9 @@ void ForLoopGenerator::movePreOpToCurrentAnchor( rewriteOperationAsVectorize(b, groupIdx, &movingOperation); // 3. move opeartions to current for block - moveOperationsToCurrentForBody(groupIdx, b, inductionVars, - currentLoopStateIdxMap, loopState, - originalOperandLoopArgsMap, movingOperation); + moveOperationsToCurrentForBody( + groupIdx, b, inductionVars, currentLoopStateIdxMap, loopState, + originalOperandLoopArgsMap, movingOperation, indiceLoopMap); // 4. get next anchor args getInitArgsToNextAnchor(anchorIdx, groupIdx, candidateQueue, loopState, @@ -1189,7 +1195,8 @@ void ForLoopGenerator::movePostOpToCurrentAnchor( DenseMap &originalOperandLoopArgsMap, DenseMap &loopArgsOriginalOperandMap, const SmallVector &nextAnchorResults, - DenseMap &forResultOrignalResultMap) { + DenseMap &forResultOrignalResultMap, + DenseMap> &indiceLoopMap) { // 1. move post-op to current loop body std::queue movingOperations; @@ -1199,7 +1206,7 @@ void ForLoopGenerator::movePostOpToCurrentAnchor( moveOperationsToCurrentForBody(anchorIdx, b, inductionVars, operandIdxMap, loopState, originalOperandLoopArgsMap, - movingOperations); + movingOperations, indiceLoopMap); // 2. replace correct for loop result to post-op IRRewriter rewriter(b); @@ -1223,7 +1230,7 @@ void ForLoopGenerator::generateLoopResults( SmallVector results; DenseMap currentResultMap; getResultInCurrentOps(anchorIdx, groupIdx, movedOperation, results, - currentResultMap); + nextAnchorResultsIdxMap, currentResultMap); llvm::MapVector> &groupResults = getGroupOpResults()[groupIdx]; @@ -1260,7 +1267,8 @@ scf::ForOp ForLoopGenerator::reductionAxisGenerateForLoop( DenseMap &nextAnchorResultsIdxMap, llvm::SmallVector &inductionVars, DenseMap &forResultOrignalResultMap, - DenseMap &originalResultForResultMap) { + DenseMap &originalResultForResultMap, + DenseMap> &indiceLoopMap) { MultiReductionCanonicalizer rdCanonicalizer = getMultiRdCanonicalizers()[groupIdx]; @@ -1302,10 +1310,11 @@ scf::ForOp ForLoopGenerator::reductionAxisGenerateForLoop( DenseMap currentArgsOriginalMap = loopArgsOriginalOperandMap; DenseMap originalArgsMap, argsOriginalMap; - movePreOpToCurrentAnchor( - anchorIdx, groupIdx, b, inductionVars, loopState, - currentLoopStateIdxMap, nextAnchorArgsIdxMap, nextAnchorArgs, - opQueue, movedOperation, originalArgsMap, argsOriginalMap); + movePreOpToCurrentAnchor(anchorIdx, groupIdx, b, inductionVars, + loopState, currentLoopStateIdxMap, + nextAnchorArgsIdxMap, nextAnchorArgs, + opQueue, movedOperation, originalArgsMap, + argsOriginalMap, indiceLoopMap); // replace reduction init args if (originalOperandLoopArgsMap.contains(multireductionOp.getAcc())) { @@ -1322,14 +1331,14 @@ scf::ForOp ForLoopGenerator::reductionAxisGenerateForLoop( nextAnchorArgsIdxMap, nextAnchorArgs, originalArgsMap, argsOriginalMap, nextAnchorResults, nextAnchorResultsIdxMap, inductionVars, forResultOrignalResultMap, - originalResultForResultMap); + originalResultForResultMap, indiceLoopMap); // 3. move postOp to current body movePostOpToCurrentAnchor( b, anchorIdx, groupIdx, nxtFor->getResults(), b.getBlock(), opQueue, movedOperation, inductionVars, currentLoopStateIdxMap, loopState, currentoriginalArgsMap, currentArgsOriginalMap, - nextAnchorResults, forResultOrignalResultMap); + nextAnchorResults, forResultOrignalResultMap, indiceLoopMap); // 4. 
generate loop results generateLoopResults(b, loc, anchorIdx, groupIdx, nextAnchorResults, @@ -1344,7 +1353,8 @@ scf::ForOp ForLoopGenerator::reductionAxisGenerateForLoop( size_t retIdx = nextAnchorArgsIdxMap[forResultOrignalResultMap[lastForResult]]; Value forRes = nxtFor->getResults()[retIdx]; - + // accumulate for loop iter args must be last, so we just put the + // reduction result as the last result nextAnchorResults.emplace_back(forRes); nextAnchorResultsIdxMap[forRes] = nextAnchorResults.size() - 1; forResultOrignalResultMap[forRes] = @@ -1366,6 +1376,7 @@ scf::ForOp ForLoopGenerator::reductionAxisGenerateForLoop( } movingOperation.push(curOp); } + // remove all the multi_reduction operation while (!opQueue.empty()) { Operation *curOp = opQueue.front(); if (isa(curOp)) { @@ -1379,44 +1390,47 @@ scf::ForOp ForLoopGenerator::reductionAxisGenerateForLoop( moveOperationsToCurrentForBody( groupIdx, b, inductionVars, currentLoopStateIdxMap, loopState, - originalOperandLoopArgsMap, movingOperation); - if (!rdCanonicalizer.getIsEmptyReduction()) { - int accValIdx = currentLoopStateIdxMap - [originalOperandLoopArgsMap[multireductionOp.getAcc()]]; + originalOperandLoopArgsMap, movingOperation, indiceLoopMap); + // if (!rdCanonicalizer.getIsEmptyReduction()) { + int accValIdx = currentLoopStateIdxMap + [originalOperandLoopArgsMap[multireductionOp.getAcc()]]; + // check acc val is the first args + assert(accValIdx == 0); - Value reductionResult = makeArithReduction( - b, loc, multireductionOp.getKind(), - multireductionOp.getSource(), loopState[accValIdx]); + Value reductionResult = makeArithReduction( + b, loc, multireductionOp.getKind(), multireductionOp.getSource(), + loopState[accValIdx]); - movePostOpToCurrentAnchor( - b, anchorIdx, groupIdx, ValueRange(), b.getBlock(), opQueue, - movingOperation, inductionVars, currentLoopStateIdxMap, - loopState, originalOperandLoopArgsMap, - loopArgsOriginalOperandMap, nextAnchorResults, - forResultOrignalResultMap); - - nextAnchorResults.clear(); - nextAnchorResults.emplace_back(reductionResult); - nextAnchorResultsIdxMap[reductionResult] = 0; - forResultOrignalResultMap[reductionResult] = - multireductionOp->getResults()[0]; - originalResultForResultMap[multireductionOp->getResults()[0]] = - reductionResult; - getResultInCurrentOps(anchorIdx, groupIdx, movingOperation, - nextAnchorResults, forResultOrignalResultMap); + movePostOpToCurrentAnchor( + b, anchorIdx, groupIdx, ValueRange(), b.getBlock(), opQueue, + movingOperation, inductionVars, currentLoopStateIdxMap, loopState, + originalOperandLoopArgsMap, loopArgsOriginalOperandMap, + nextAnchorResults, forResultOrignalResultMap, indiceLoopMap); - } else { - Value sourceVal = multireductionOp.getSource(); - nextAnchorResults.clear(); - nextAnchorResults.emplace_back(sourceVal); - nextAnchorResultsIdxMap[sourceVal] = 0; - forResultOrignalResultMap[sourceVal] = - multireductionOp->getResults()[0]; - originalResultForResultMap[multireductionOp->getResults()[0]] = - sourceVal; - getResultInCurrentOps(anchorIdx, groupIdx, movingOperation, - nextAnchorResults, forResultOrignalResultMap); - } + nextAnchorResults.clear(); + nextAnchorResults.emplace_back(reductionResult); + nextAnchorResultsIdxMap[reductionResult] = 0; + forResultOrignalResultMap[reductionResult] = + multireductionOp->getResults()[0]; + originalResultForResultMap[multireductionOp->getResults()[0]] = + reductionResult; + getResultInCurrentOps(anchorIdx, groupIdx, movingOperation, + nextAnchorResults, nextAnchorResultsIdxMap, + 
forResultOrignalResultMap); + + // } else { + // Value sourceVal = multireductionOp.getSource(); + // nextAnchorResults.clear(); + // nextAnchorResults.emplace_back(sourceVal); + // nextAnchorResultsIdxMap[sourceVal] = 0; + // forResultOrignalResultMap[sourceVal] = + // multireductionOp->getResults()[0]; + // originalResultForResultMap[multireductionOp->getResults()[0]] = + // sourceVal; + // getResultInCurrentOps(anchorIdx, groupIdx, movingOperation, + // nextAnchorResults, + // forResultOrignalResultMap); + // } maybeYieldValue(b, loc, nextAnchorResults); } }); @@ -1434,7 +1448,8 @@ scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( SmallVector &inductionVars, DenseMap &originalOperandLoopArgsMap, DenseMap &loopArgsOriginalOperandMap, - DenseMap &forResultOrignalResultMap) { + DenseMap &forResultOrignalResultMap, + DenseMap> &indiceLoopMap) { auto &rdCanonicalizer = getMultiRdCanonicalizers()[groupIdx]; vector::MultiDimReductionOp &multiReductionOp = rdCanonicalizer.getCandidateOps()[0]; @@ -1490,14 +1505,14 @@ scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( parallelIdx, groupIdx, b, inductionVars, loopState, currentLoopStateIdxMap, nextAnchorArgsIdxMap, nextAnchorArgs, opQueue, movedQueue, originalOperandLoopArgsMap, - loopArgsOriginalOperandMap); + loopArgsOriginalOperandMap, indiceLoopMap); if (parallelIdx == parallelAxis.size() - 1) { - // Ensure accumalate expression in this parallel anchor position. - // If it not appear in current anchor, we must move it in here. + // Ensure accumalate expression appear in this parallel anchor + // position. If it not appear in current anchor, we must move it in + // here. // 1. delete it in operation queue // 2. move it in current movedqueue - DenseMap> srcOpCanoniclizedMap; DenseSet argsSet(nextAnchorArgs.begin(), nextAnchorArgs.end()); std::queue checkAccQueue(movedQueue); @@ -1517,6 +1532,7 @@ scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( break; } if (accInitVal) { + // we put initVal at last for loop args if (!argsSet.contains(accInitVal)) { nextAnchorArgs.emplace_back(accInitVal); nextAnchorArgsIdxMap[accInitVal] = nextAnchorArgs.size() - 1; @@ -1540,7 +1556,8 @@ scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( b, groupIdx, parallelIdx + 1, nextAnchorArgsIdxMap, nextAnchorArgs, nextAnchorResults, nextAnchorResultsIdxMap, inductionVars, originalOperandLoopArgsMap, - loopArgsOriginalOperandMap, forResultOrignalResultMap); + loopArgsOriginalOperandMap, forResultOrignalResultMap, + indiceLoopMap); } else if (parallelAxis.size() - 1 == parallelIdx) { nxtFor = reductionAxisGenerateForLoop( @@ -1548,7 +1565,8 @@ scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( nextAnchorArgs, originalOperandLoopArgsMap, loopArgsOriginalOperandMap, nextAnchorResults, nextAnchorResultsIdxMap, inductionVars, - forResultOrignalResultMap, originalResultForResultMap); + forResultOrignalResultMap, originalResultForResultMap, + indiceLoopMap); } // 3. move postOp to current body @@ -1557,7 +1575,7 @@ scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( nxtFor->getBlock(), opQueue, movedQueue, inductionVars, currentLoopStateIdxMap, loopState, currentOriginalOperandMap, currentOperandOriginalMap, nextAnchorResults, - forResultOrignalResultMap); + forResultOrignalResultMap, indiceLoopMap); // 4. 
generate loop results generateLoopResults(b, loc, parallelIdx, groupIdx, nextAnchorResults, @@ -1567,6 +1585,11 @@ scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( } else if (parallelIdx == parallelAxis.size()) { + DenseMap tmpOriginOperandLoopArgsMap = + originalOperandLoopArgsMap; + DenseMap tmpLoopArgsOriginalOperandMap = + loopArgsOriginalOperandMap; + // get accumualte value Attribute initValueAttr; getReductionInitAttr(multiReductionOp, initValueAttr); @@ -1576,9 +1599,11 @@ scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( getVectorzedType(multiReductionOp, dimSize), {initValueAttr})); + // put accumulte val at first for loop args DenseMap localAnchorArgsIdxMap; DenseMap localOriginalOperandLoopArgsMap, localLoopArgsOriginalOperandMap; + SmallVector argsArray; argsArray.emplace_back(accVal); localAnchorArgsIdxMap[accVal] = 0; @@ -1602,11 +1627,14 @@ scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( b, groupIdx, 0, parallelIdx, localAnchorArgsIdxMap, argsArray, localOriginalOperandLoopArgsMap, localLoopArgsOriginalOperandMap, nextAnchorResults, nextAnchorResultsIdxMap, inductionVars, - forResultOrignalResultMap, originalResultForResultMap); + forResultOrignalResultMap, originalResultForResultMap, + indiceLoopMap); // insert accumulate value to original vector - // TODO: fix first accumualte idx use map - auto accRes = nxtFor->getResults()[0]; + Value nxtForAccVal = + originalResultForResultMap[multiReductionOp->getResults()[0]]; + size_t accIdx = nextAnchorResultsIdxMap[nxtForAccVal]; + auto accRes = nxtFor->getResults()[accIdx]; Operation *reductionOp = b.create( loc, multiReductionOp.getKind(), accRes); @@ -1614,29 +1642,30 @@ scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( loc, reductionOp->getResult(0), loopState[accLoopStateIdx], iv); // generate loop result - SmallVector currentAnchorResults; + SmallVector currentAnchorResults(loopState.size()); DenseMap currentResultMap; DenseMap currentResultIdxMap; - currentAnchorResults.emplace_back(insertOp->getResults()[0]); + currentAnchorResults[accLoopStateIdx] = insertOp->getResults()[0]; // reduce axis for loop first result we has already processed above currentResultMap[insertOp->getResults()[0]] = multiReductionOp->getResults()[0]; - currentResultIdxMap[insertOp->getResults()[0]] = 0; - + currentResultIdxMap[insertOp->getResults()[0]] = accLoopStateIdx; for (auto [idx, x] : llvm::enumerate(nextAnchorResults)) { - if (idx == 0) { + if (forResultOrignalResultMap[x] == + multiReductionOp->getResults()[0]) { continue; } Value originalResult = forResultOrignalResultMap[x]; - size_t forResultIdx = nextAnchorResultsIdxMap[x]; - currentAnchorResults.emplace_back( - nxtFor->getResults()[forResultIdx]); - currentResultIdxMap[nxtFor->getResults()[forResultIdx]] = idx; - currentResultMap[nxtFor->getResults()[forResultIdx]] = - originalResult; + size_t itrIdx = currentLoopStateIdxMap + [tmpOriginOperandLoopArgsMap[originalResult]]; + currentAnchorResults[itrIdx] = nxtFor->getResults()[idx]; + currentResultIdxMap[nxtFor->getResults()[idx]] = itrIdx; + currentResultMap[nxtFor->getResults()[idx]] = originalResult; } nextAnchorResults.clear(); + forResultOrignalResultMap.clear(); + nextAnchorResultsIdxMap.clear(); nextAnchorResults = std::move(currentAnchorResults); forResultOrignalResultMap = std::move(currentResultMap); nextAnchorResultsIdxMap = std::move(currentResultIdxMap); @@ -1726,6 +1755,8 @@ ForLoopGenerator::generateMultiReductionForLoop(const size_t grpIdx) { getMultiRdCanonicalizers()[grpIdx]; 
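// Roughly, the loop nest built below looks like (sketch):
//   scf.for %i0 ... {        // one level per parallel axis
//     scf.for %r0 ... {      // one level per reduction axis
//       ...                  // accumulator threaded through as a loop iter_arg
//     }
//   }
// Because this loop order no longer matches the vector dimension order, the
// source-related transfer_read ops are recorded in indiceLoopMap below so that
// their indices can later be rewired to the matching induction variables.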
vector::MultiDimReductionOp multiReductionOp = rdCanonicalizer.getCandidateOps()[0]; + SmallVector ¶llelAxis = rdCanonicalizer.getParallelAxis(); + SmallVector &reductionAxis = rdCanonicalizer.getReductionAxis(); std::queue &prevOps = rdCanonicalizer.getPrevOps(); std::queue &postOps = rdCanonicalizer.getPostOps(); std::queue &accRelatedOps = rdCanonicalizer.getAccRelatedOps(); @@ -1738,6 +1769,27 @@ ForLoopGenerator::generateMultiReductionForLoop(const size_t grpIdx) { getPostOps(postOps, copyOpQueue, multiReductionOp); classifyAccRelatedOps(accRelatedOps, sourceRelatedOps, multiReductionOp.getAcc().getDefiningOp(), prevOps); + + // mark source read operation need to set correct for loop var idx + std::queue tmpSourceQ(sourceRelatedOps); + DenseMap> indiceLoopMap; + DenseMap varLoopIdxMap; + VectorType groupVector = + getFusionStrategy().getGroupBiggestRankVectorType()[grpIdx]; + for (size_t i = 0; i < parallelAxis.size(); i++) { + varLoopIdxMap[parallelAxis[i]] = i; + } + for (size_t i = parallelAxis.size(); i < groupVector.getRank(); i++) { + varLoopIdxMap[reductionAxis[i - parallelAxis.size()]] = i; + } + while (!tmpSourceQ.empty()) { + auto *curOp = tmpSourceQ.front(); + tmpSourceQ.pop(); + if (isa(curOp)) { + getCurrentGroupIndiceLoopMap(indiceLoopMap, grpIdx, curOp, varLoopIdxMap); + } + } + // move accumulate related operation to operation first std::queue rectifyQueue; DenseSet pushedSet; @@ -1779,8 +1831,7 @@ ForLoopGenerator::generateMultiReductionForLoop(const size_t grpIdx) { scf::ForOp forOp = parallelAxisGenerateForLoop( opBuilder, grpIdx, 0, currentLoopStateIdxMap, initArgs, nextAnchorResults, nextAnchorResultsIdxMap, inductionVars, originalOperandLoopArgsMap, - loopArgsOriginalOperandMap, forResultOrignalResultMap); - + loopArgsOriginalOperandMap, forResultOrignalResultMap, indiceLoopMap); auto replaceIfFn = [&](OpOperand &use) { return use.getOwner()->getBlock() == forOp->getBlock(); }; @@ -1802,7 +1853,7 @@ scf::ForOp ForLoopGenerator::generateTransposeScalarDataMovement( const ValueRange &iterArgs, DenseMap &tpAxisMap) { auto &tpCanonicalizer = getTransposeCanonicalizers()[grpIdx]; vector::TransposeOp &tpOp = tpCanonicalizer.getCandidateOps()[0]; - VectorType vtType = tpOp.getVector().getType(); + VectorType vtType = tpOp.getResultVectorType(); size_t rank = vtType.getRank(); auto zero = makeIndexArithConstantOp(opBuilder, loc, 0); @@ -1846,6 +1897,8 @@ scf::ForOp ForLoopGenerator::generateTransposeScalarDataMovement( itrIdx++; } + rectifyWriteOperationIndice(&successorWriteOp, writeVars); + auto writeOp = b.create( loc, transferReadOp->getResults()[0], loopState[0], writeVars, inBoundsVal); @@ -1890,8 +1943,6 @@ scf::ForOp ForLoopGenerator::generateShapeCastReadWriteLoop( // inner most body of the loop if (forDimIdx == rank - 1) { - sourceType.dump(); - destType.dump(); // transfer read from source tensor Value source = scOp->getOperand(0); auto readSourceOp = @@ -1981,9 +2032,12 @@ scf::ForOp ForLoopGenerator::generateShapeCastReadWriteLoop( /*padding=*/padValue, /*inBounds=*/inBoundsVal); + SmallVector writeVars = + loopType == sourceType ? smallRankShapeVars : inductionVars; + + rectifyWriteOperationIndice(&successorWriteOp, writeVars); auto writeOp = b.create( - loc, transferReadOp->getResults()[0], loopState[0], - loopType == sourceType ? 
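The `varLoopIdxMap` built above encodes the loop order used for the multi-reduction: parallel axes take the outer loop levels and reduction axes follow at the inner levels. A standalone illustration of that mapping (plain containers, no MLIR types):

```cpp
// Axis-to-loop-level map: parallel axes occupy the outer positions,
// reduction axes the inner positions.
#include <cstdint>
#include <unordered_map>
#include <vector>

std::unordered_map<int64_t, size_t>
buildAxisToLoopLevel(const std::vector<int64_t> &parallelAxis,
                     const std::vector<int64_t> &reductionAxis) {
  std::unordered_map<int64_t, size_t> varLoopIdx;
  for (size_t i = 0; i < parallelAxis.size(); ++i)
    varLoopIdx[parallelAxis[i]] = i;
  for (size_t i = 0; i < reductionAxis.size(); ++i)
    varLoopIdx[reductionAxis[i]] = parallelAxis.size() + i;
  return varLoopIdx;
}
// e.g. parallel = {0, 2}, reduction = {1}:
//   axis 0 -> level 0, axis 2 -> level 1, axis 1 -> level 2
```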
smallRankShapeVars : inductionVars, + loc, transferReadOp->getResults()[0], loopState[0], writeVars, inBoundsVal); maybeYieldValue(b, loc, writeOp->getResults()); } else { @@ -1995,6 +2049,44 @@ scf::ForOp ForLoopGenerator::generateShapeCastReadWriteLoop( }); } +void ForLoopGenerator::rectifyWriteOperationIndice( + vector::TransferWriteOp *originalWriteOp, + SmallVectorImpl &writeVars) { + VectorType sucessWriteVectorType = originalWriteOp->getVectorType(); + ShapedType successWriteTensorType = + cast(originalWriteOp->getResultTypes()[0]); + size_t inMutableIdx = + successWriteTensorType.getRank() - sucessWriteVectorType.getRank(); + Operation::operand_range writeIndices = originalWriteOp->getIndices(); + + for (size_t i = 0; i < inMutableIdx; i++) { + writeVars[i] = (writeIndices[i]); + } +} + +void ForLoopGenerator::rectifyReadOperationIndice( + vector::TransferReadOp *originalReadOp, VectorType loopType, + ArrayRef inductionVars, SmallVectorImpl &readVars) { + VectorType originalReadVectorType = originalReadOp->getVectorType(); + // currently only broadcast (fuse as transfer_read) will move into more inner + // loop + // TODO: Need to better process the broadcast operation + if (originalReadVectorType.getRank() - 1 < + getFusionStrategy().getOpAnchorPos()[*originalReadOp]) { + return; + } + int64_t itrIdx = loopType.getRank() - 1; + int64_t readIdx = originalReadVectorType.getRank() - 1; + while (itrIdx >= 0 and readIdx >= 0) { + if (originalReadVectorType.getShape()[readIdx] == + loopType.getShape()[itrIdx]) { + readVars[readIdx] = inductionVars[itrIdx]; + readIdx--; + } + itrIdx--; + } +} + /// generate transpose for loop scf::ForOp ForLoopGenerator::generateShapeCastForLoop(const size_t grpIdx) { @@ -2041,6 +2133,25 @@ scf::ForOp ForLoopGenerator::generateShapeCastForLoop(const size_t grpIdx) { return forOp; } +/// mark which operation need to set correct for loop var idx +/// due to sometimes we need to chage for loop order like reduce operation. 
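`rectifyReadOperationIndice` above pairs the read's dimensions with loop induction variables by walking both shapes from the innermost dimension outward. The standalone sketch below shows the same matching on plain shapes; -1 marks a dimension that keeps its original index.

```cpp
// Innermost-out shape matching, as in rectifyReadOperationIndice: a read
// dimension is paired with a loop level whenever the static sizes agree.
#include <cstdint>
#include <vector>

void matchTrailingDims(const std::vector<int64_t> &readShape,
                       const std::vector<int64_t> &loopShape,
                       std::vector<int64_t> &readDimToLoopLevel) {
  readDimToLoopLevel.assign(readShape.size(), -1); // -1: keep original index
  int64_t loopIdx = static_cast<int64_t>(loopShape.size()) - 1;
  int64_t readIdx = static_cast<int64_t>(readShape.size()) - 1;
  while (loopIdx >= 0 && readIdx >= 0) {
    if (readShape[readIdx] == loopShape[loopIdx]) {
      readDimToLoopLevel[readIdx] = loopIdx;
      --readIdx;
    }
    --loopIdx;
  }
}
```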
+void ForLoopGenerator::getCurrentGroupIndiceLoopMap( + DenseMap> &indiceLoopMap, + const size_t groupId, Operation *op, + const DenseMap &setIdxMap) { + if (setIdxMap.empty()) { + DenseMap forIdxMap; + VectorType groupVector = + getFusionStrategy().getGroupBiggestRankVectorType()[groupId]; + for (size_t i = 0; i < groupVector.getRank(); i++) { + forIdxMap[i] = i; + } + indiceLoopMap[op] = forIdxMap; + return; + } + indiceLoopMap[op] = setIdxMap; +} + void ForLoopGenerator::clearCurrentOperationGroup(size_t grpIdx) { std::queue().swap(getFusionStrategy().getOpGroups()[grpIdx]); }; @@ -2052,8 +2163,9 @@ scf::ForOp ForLoopGenerator::generateTransposeForLoop(const size_t grpIdx) { TransposeCanonicalizer &tpCanonicalizer = getTransposeCanonicalizers()[grpIdx]; vector::TransposeOp &tpOp = tpCanonicalizer.getCandidateOps()[0]; + IRRewriter rewriter(func); - VectorType vtType = tpOp.getVector().getType(); + VectorType vtType = tpOp.getResultVectorType(); size_t rank = vtType.getRank(); if (rank < 2) { llvm::llvm_unreachable_internal( @@ -2065,13 +2177,7 @@ scf::ForOp ForLoopGenerator::generateTransposeForLoop(const size_t grpIdx) { ArrayRef permutation = tpOp.getPermutation(); DenseSet permuteSet(permutation.begin(), permutation.end()); bool isTwoDTranspose = tpCanonicalizer.isTwoDTranspose(); - const int tpStep = 16; - // currently we only support shape that is an integer multiple of tpStep - if (vtType.getShape()[tpCanonicalizer.getFirstTpIdx()] % tpStep != 0 or - vtType.getShape()[tpCanonicalizer.getSecondTpIdx()] % tpStep != 0) { - isTwoDTranspose = false; - } - OpBuilder b(tpOp); + vector::TransferWriteOp successorWriteOp; for (Operation *x : tpOp->getUsers()) { if (isa(x)) { @@ -2082,17 +2188,56 @@ scf::ForOp ForLoopGenerator::generateTransposeForLoop(const size_t grpIdx) { // iterArgs.emplace_back(successorWriteOp->getOperands()[1]); SmallVector operands; DenseMap operandIdxMap; - DenseMap originalOperandMap; - auto &initArgs = getGroupOpInitArgs()[grpIdx]; + DenseMap originalOperandMap, operandOriginalMap, resultIdxMap, + forResultOrignalResultMap; + SetVector &initArgs = getGroupOpInitArgs()[grpIdx]; for (Value x : initArgs) { operands.emplace_back(x); operandIdxMap[x] = operands.size() - 1; originalOperandMap[x] = x; + operandOriginalMap[x] = x; } SmallVector iterArgs(operands.begin(), operands.end()); SmallVector inductionVars; - IRRewriter rewriter(func); + // don't need to do the transpose + if (tpCanonicalizer.isTransposeOnAllOneDim()) { + removeOpInCurrentGroups(grpIdx, tpOp); + + // generate nested for loop + SmallVector nextLoopResults; + DenseMap resultIdxMap; + SmallVector inductionVars; + DenseMap forResultOrignalResultMap; + Operation *firstOp = getFusionStrategy().getOpGroups()[grpIdx].front(); + OpBuilder b(firstOp); + VectorType groupVector = + getFusionStrategy().getGroupBiggestRankVectorType()[grpIdx]; + ArrayRef shapes = groupVector.getShape(); + + DenseMap> indiceLoopMap; + + scf::ForOp forOp = constructNestedForOp( + 0, grpIdx, b, firstOp->getLoc(), iterArgs, shapes, inductionVars, + operandIdxMap, originalOperandMap, operandOriginalMap, nextLoopResults, + resultIdxMap, forResultOrignalResultMap, indiceLoopMap); + + auto replaceIfFn = [&](OpOperand &use) { + return use.getOwner()->getBlock() == forOp->getBlock(); + }; + for (auto x : nextLoopResults) { + auto originalResult = forResultOrignalResultMap[x]; + rewriter.replaceOpUsesWithIf(originalResult.getDefiningOp(), + forOp->getResults()[resultIdxMap[x]], + replaceIfFn); + } + // clear current group operation + 
clearCurrentOperationGroup(grpIdx); + return forOp; + } + OpBuilder b(tpOp); + int tpStep = TransposeCanonicalizer::TRANSPOSE_KERNEL::KERNEL_16X16; + // only contains last dim can use fast transpose algorithm if (permuteSet.contains(rank - 1) and isTwoDTranspose) { scf::ForOp forOp = generateTransposeForLoopWithLastDim( b, grpIdx, 0, tpStep, tpOp.getLoc(), inductionVars, iterArgs, @@ -2190,8 +2335,27 @@ void TransposeCanonicalizer::prepareSpecialOperationInfo() { } } +bool TransposeCanonicalizer::isTransposeOnAllOneDim() { + vector::TransposeOp tpOp = getCandidateOps()[0]; + ArrayRef permutation = tpOp.getPermutation(); + VectorType tpVectorType = tpOp.getResultVectorType(); + int64_t itrIdx = 0; + while (itrIdx < tpVectorType.getRank()) { + if (itrIdx == permutation[itrIdx]) { + itrIdx++; + continue; + } + if (tpVectorType.getShape()[itrIdx] != 1) { + return false; + } + itrIdx++; + } + return true; +} + bool TransposeCanonicalizer::isTwoDTranspose() { ArrayRef permutation = getCandidateOps()[0].getPermutation(); + size_t rank = permutation.size(); int diffCount = 0; // get the first transpose axis @@ -2219,6 +2383,13 @@ bool TransposeCanonicalizer::isTwoDTranspose() { } itrIdx++; } + const int tpStep = 16; + VectorType vtType = getCandidateOps()[0].getResultVectorType(); + // currently we only support shape that is an integer multiple of tpStep + if (vtType.getShape()[getFirstTpIdx()] % tpStep != 0 or + vtType.getShape()[getSecondTpIdx()] % tpStep != 0) { + return false; + } return diffCount == 2; } @@ -2388,15 +2559,13 @@ void CanonicalizerVectorOperation::run() { // Query groupResultYeildSet to map operaion result value to scf.yield // result value. - // printGroupOps(getFusionStrategy().getOpGroups()); - // std::cout << "___________ before analysis ________________" - // << "\n"; + printGroupOps(getFusionStrategy().getOpGroups()); + std::cout << "___________ before analysis ________________" + << "\n"; analysisGroupOperaion(); - // std::cout << "___________ after analysis ________________" - // << "\n"; - // printGroupOps(getFusionStrategy().getOpGroups()); - - func->dump(); + std::cout << "___________ after analysis ________________" + << "\n"; + printGroupOps(getFusionStrategy().getOpGroups()); // Speical Operation Canonicalization canonicalizeSpecialOperation(); @@ -2447,12 +2616,13 @@ void CanonicalizerVectorOperation::run() { } /// -void setOperationCorrectOperand( +void ForLoopGenerator::setOperationCorrectOperand( Operation *op, const ValueRange &iterArgs, const DenseMap &operandIdxMap, DenseMap &originalOperandLoopArgsMap, ArrayRef inductionVars, - const DenseMap &opPermuationMap) { + const DenseMap &opPermuationMap, + DenseMap> &indiceloopMap) { for (auto [idx, opd] : llvm::enumerate(op->getOperands())) { if (!originalOperandLoopArgsMap.contains(opd)) { continue; @@ -2472,37 +2642,58 @@ void setOperationCorrectOperand( auto dimExpr = permutationMap.getResults(); for (auto [idx, x] : llvm::enumerate(dimExpr)) { - if (!isa(x)) { + if (!isa(x)) { llvm::llvm_unreachable_internal( "Permuatation map must contains dim expr."); } - auto dim = dyn_cast(x).getPosition(); + size_t dim; + if (auto d = dyn_cast(x)) { + dim = d.getPosition(); + } + if (auto d = dyn_cast(x)) { + dim = d.getValue(); + } ShapedType tensorType = cast(op->getOperandTypes()[offset - 1]); + size_t varIdx = dim; if (tensorType.getRank() > (int64_t)inductionVars.size()) { int64_t tensorOffset = tensorType.getRank() - inductionVars.size(); - - op->setOperand(dim + offset, inductionVars[dim - tensorOffset]); - 
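Two eligibility tests appear in the transpose handling above: a transpose that only moves unit-sized axes is a no-op and is simply removed from the group, and the blocked 16x16 kernel is only used when both transposed dimensions are multiples of the kernel size. Condensed standalone versions of both checks:

```cpp
// Condensed versions of the two transpose checks above (plain shapes only).
#include <cstdint>
#include <vector>

// A permutation that only moves size-1 axes does not reorder any data.
bool permutesOnlyUnitDims(const std::vector<int64_t> &shape,
                          const std::vector<int64_t> &permutation) {
  for (size_t i = 0; i < permutation.size(); ++i) {
    if (static_cast<int64_t>(i) == permutation[i])
      continue; // axis stays in place
    if (shape[i] != 1)
      return false; // a real (non-unit) axis is moved
  }
  return true;
}

// The 16x16 kernel needs both transposed dims to be multiples of 16.
bool canUse16x16Kernel(int64_t firstTpDimSize, int64_t secondTpDimSize) {
  constexpr int64_t kKernel = 16;
  return firstTpDimSize % kKernel == 0 && secondTpDimSize % kKernel == 0;
}
```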
continue; + if (dim < tensorOffset) { + continue; + } + varIdx = dim - tensorOffset; } - - op->setOperand(dim + offset, inductionVars[dim]); + if (indiceloopMap.contains(op)) { + op->setOperand(dim + offset, inductionVars[indiceloopMap[op][varIdx]]); + } else { + op->setOperand(dim + offset, inductionVars[varIdx]); + } + } + if (auto readOp = dyn_cast(op)) { + size_t grpIdx = getFusionStrategy().getOpGroupIndexMap()[op]; + VectorType loopType = + getFusionStrategy().getGroupBiggestRankVectorType()[grpIdx]; + SmallVector readIndices(readOp.getIndices().begin(), + readOp.getIndices().end()); + rectifyReadOperationIndice(&readOp, loopType, inductionVars, readIndices); + readOp.getIndicesMutable().assign(readIndices); } } } scf::ForOp ForLoopGenerator::constructNestedForOp( const size_t forDimIdx, const size_t groupIdx, OpBuilder &b, - const Location &loc, const ValueRange &iterArgs, VectorType type, + const Location &loc, const ValueRange &iterArgs, const ArrayRef &dims, SmallVector &inductionVars, DenseMap &operandIdxMap, DenseMap &originalOperandMap, DenseMap &operandOriginalMap, SmallVector &nextAnchorResults, DenseMap &nextAnchorResultsIdxMap, - DenseMap &forResultOrignalResultMap) { - const int loop_step = getDataTypeValidSteps(type); + DenseMap &forResultOrignalResultMap, + DenseMap> &indiceLoopMap) { + const int loop_step = getFusionStrategy().getGroupMaxSteps()[groupIdx]; // loop initialization variable auto zero = makeIndexArithConstantOp(b, loc, 0); auto forSteps = makeIndexArithConstantOp( @@ -2511,7 +2702,7 @@ scf::ForOp ForLoopGenerator::constructNestedForOp( // Create a loop and move vectorized operation into loops. auto forOp = b.create( - b.getUnknownLoc(), zero, numIter, forSteps, iterArgs, + loc, zero, numIter, forSteps, iterArgs, [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) { inductionVars.emplace_back(iv); @@ -2527,12 +2718,13 @@ scf::ForOp ForLoopGenerator::constructNestedForOp( rewriteOperationAsVectorize(b, groupIdx, &movingOperation); // 3. 
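The rank-mismatch handling above can be summarized as follows: when the transfer op's tensor has more dimensions than the surrounding loop nest, the leading dimensions keep their original indices and only the trailing ones are driven by induction variables (optionally remapped through `indiceLoopMap`). A standalone sketch of that index selection, with illustrative names:

```cpp
// Which induction variable (if any) drives tensor dimension `dimPos`?
#include <cstddef>
#include <optional>

std::optional<size_t> loopVarForDim(size_t dimPos, size_t tensorRank,
                                    size_t numLoops) {
  if (tensorRank <= numLoops)
    return dimPos;                       // every dimension has its own loop
  size_t offset = tensorRank - numLoops; // leading dims not covered by loops
  if (dimPos < offset)
    return std::nullopt;                 // keep the operand's original index
  return dimPos - offset;                // possibly remapped via indiceLoopMap
}
```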
move opeartions to current for block - moveOperationsToCurrentForBody(groupIdx, b, inductionVars, - operandIdxMap, loopState, - originalOperandMap, movingOperation); + moveOperationsToCurrentForBody( + groupIdx, b, inductionVars, operandIdxMap, loopState, + originalOperandMap, movingOperation, indiceLoopMap); getResultInCurrentOps(forDimIdx, groupIdx, movingOperation, - nextAnchorResults, forResultOrignalResultMap); + nextAnchorResults, nextAnchorResultsIdxMap, + forResultOrignalResultMap); maybeYieldValue(b, loc, nextAnchorResults); } else { // outter loop @@ -2549,25 +2741,25 @@ scf::ForOp ForLoopGenerator::constructNestedForOp( movePreOpToCurrentAnchor( forDimIdx, groupIdx, b, inductionVars, loopState, operandIdxMap, nextAnchorArgsIdxMap, nextAnchorArgs, opQueue, movedQueue, - originalOperandMap, operandOriginalMap); + originalOperandMap, operandOriginalMap, indiceLoopMap); auto nxtFor = constructNestedForOp( - forDimIdx + 1, groupIdx, b, loc, loopState, type, dims, - inductionVars, nextAnchorArgsIdxMap, originalOperandMap, - operandOriginalMap, nextAnchorResults, nextAnchorResultsIdxMap, - forResultOrignalResultMap); + forDimIdx + 1, groupIdx, b, loc, loopState, dims, inductionVars, + nextAnchorArgsIdxMap, originalOperandMap, operandOriginalMap, + nextAnchorResults, nextAnchorResultsIdxMap, + forResultOrignalResultMap, indiceLoopMap); movePostOpToCurrentAnchor( b, forDimIdx, groupIdx, nxtFor->getResults(), b.getBlock(), opQueue, movedQueue, inductionVars, operandIdxMap, loopState, currentOriginalOperandMap, currentOperandOriginalMap, - nextAnchorResults, forResultOrignalResultMap); + nextAnchorResults, forResultOrignalResultMap, indiceLoopMap); generateLoopResults(b, loc, forDimIdx, groupIdx, nextAnchorResults, nextAnchorResultsIdxMap, nxtFor->getResults(), movedQueue, forResultOrignalResultMap); - maybeYieldValue(b, loc, nxtFor->getResults()); + maybeYieldValue(b, loc, nextAnchorResults); } }); return forOp; @@ -2805,6 +2997,9 @@ bool hasDataDependency(Operation *op1, Operation *op2) { if (isa(op2)) { return true; } + if (isa(op1) and disableBroadcastOp) { + return true; + } // only special operation may cause data dependency if (!isSpecialOp(op1)) { @@ -2877,6 +3072,9 @@ bool hasDataDependency(Operation *op1, Operation *op2) { return true; }) .Case([&](vector::BroadcastOp broadcastOp) { + if (isa(op2)) { + return false; + } return !OpTrait::util::staticallyKnownBroadcastable( getOperationVectorType(op1, false)->getShape(), getOperationVectorType(op2)->getShape()); @@ -3165,8 +3363,59 @@ mlir::FailureOr getOperationOperateTensor(Operation *op) { }); } +void CanonicalizerCommonUsedData::removeOpInCurrentGroups(size_t grpIdx, + Operation *op) { + std::queue tmpOpQueue(getFusionStrategy().getOpGroups()[grpIdx]); + std::queue newOpQueue; + while (!tmpOpQueue.empty()) { + auto curOp = tmpOpQueue.front(); + tmpOpQueue.pop(); + if (curOp != op) { + newOpQueue.push(curOp); + continue; + } + getFusionStrategy().getOpGroupIndexMap().erase(curOp); + getFusionStrategy().getOpAnchorPos().erase(curOp); + } + getFusionStrategy().getOpGroups()[grpIdx] = newOpQueue; + + // erase and replace the operation + Operation *defOp = op->getOperand(0).getDefiningOp(); + SmallVector usesOp(op->getUsers().begin(), op->getUsers().end()); + IRRewriter rewriter(op); + rewriter.replaceOp(op, op->getOperand(0)); + // update removed operation related operation anchor position + getFusionStrategy().getOpAnchorPos()[defOp] = + getOperationMaxVectorType(defOp)->getRank() - 1; + for (Operation *x : usesOp) { + 
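The overall control flow of `constructNestedForOp` above is a recursion over loop levels: each level creates one scf.for, the innermost level hosts the vectorized operations, and every outer level moves pre-anchor operations before the recursive call and post-anchor operations after it. A heavily simplified skeleton is shown below; the operand/result bookkeeping is elided and the step is fixed to 1 instead of the group's SIMD step.

```cpp
// Simplified recursion skeleton of the nested-loop construction.
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"

using namespace mlir;

static scf::ForOp buildLevel(OpBuilder &b, Location loc, unsigned level,
                             ArrayRef<int64_t> dims, ValueRange iterArgs) {
  Value lb = b.create<arith::ConstantIndexOp>(loc, 0);
  Value ub = b.create<arith::ConstantIndexOp>(loc, dims[level]);
  Value step = b.create<arith::ConstantIndexOp>(loc, 1);
  return b.create<scf::ForOp>(
      loc, lb, ub, step, iterArgs,
      [&](OpBuilder &nb, Location nloc, Value iv, ValueRange state) {
        if (level + 1 == dims.size()) {
          // innermost anchor: vectorized operations are rewritten here
          nb.create<scf::YieldOp>(nloc, state);
          return;
        }
        // outer anchor: pre-ops go here, post-ops after the recursive call
        scf::ForOp inner = buildLevel(nb, nloc, level + 1, dims, state);
        nb.create<scf::YieldOp>(nloc, inner.getResults());
      });
}
```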
getFusionStrategy().getOpAnchorPos()[x] = + getOperationMaxVectorType(x)->getRank() - 1; + } + + // update operaiton in grpIdx group related information + updateOpGroupInfo(grpIdx); +} + +void CanonicalizerCommonUsedData::updateOpGroupInfo(size_t grpIdx) { + std::queue tmpOpQueue(getFusionStrategy().getOpGroups()[grpIdx]); + // dummy init + VectorType currentMaxRankType = + getOperationMaxVectorType(tmpOpQueue.front()).value(); + getFusionStrategy().getGroupBiggestRankVectorType()[grpIdx] = + currentMaxRankType; + + while (!tmpOpQueue.empty()) { + auto curOp = tmpOpQueue.front(); + tmpOpQueue.pop(); + VectorType type = getOperationMaxVectorType(curOp).value(); + if (type.getRank() > currentMaxRankType.getRank()) { + getFusionStrategy().getGroupBiggestRankVectorType()[grpIdx] = type; + } + } +} + void CanonicalizerCommonUsedData::updateOpOperandResultInGroups( - size_t opGid, Operation *op, Value &init, const Value &result) { + size_t opGid, Operation *op, const Value &init, const Value &result) { std::queue tmpOpQueue(getFusionStrategy().getOpGroups()[opGid]); std::queue newOpQueue; while (!tmpOpQueue.empty()) { @@ -3229,6 +3478,24 @@ void CanonicalizerCommonUsedData::generateEmptyTensorAndWrite( getOpPermuationMap()[writeOp] = writeOp.getPermutationMap(); } +template +Operation *CanonicalizerCommonUsedData::getNextTargetOperationInCurrentGroup( + Operation *curOp, const size_t grpIdx) { + std::queue tmpOpQueue(getFusionStrategy().getOpGroups()[grpIdx]); + while (!tmpOpQueue.empty()) { + auto frontOp = tmpOpQueue.front(); + if (isa(frontOp)) { + for (auto x : frontOp->getOperands()) { + if (x.getDefiningOp() == curOp) { + return frontOp; + } + } + } + tmpOpQueue.pop(); + } + return nullptr; +} + void VectorOperationAnalyzer::analysisEmptyGroup() { SmallVector, 8> &opGroups = getFusionStrategy().getOpGroups(); @@ -3287,7 +3554,7 @@ void VectorOperationAnalyzer::specialOperationRectify( auto op = tmpQueue.front(); tmpQueue.pop(); // remain transfer read operation to do the broadcast fusion - if (isa(op)) { + if (isa(op) and !disableBroadcastOp) { auto srcOp = op->getOperand(0).getDefiningOp(); assert(isa(srcOp)); // just remain write operation, it's size will @@ -3296,6 +3563,8 @@ void VectorOperationAnalyzer::specialOperationRectify( if (tmpQueue.size() <= 1) { continue; } + getFusionStrategy().getOpAnchorPos()[srcOp] = + getFusionStrategy().getOpAnchorPos()[op]; rewriter.replaceOp(op, srcOp); continue; } @@ -3398,12 +3667,26 @@ void VectorOperationAnalyzer::analysisGroupOperationResults() { continue; } if (!srcOpCanoniclizedMap.contains(sourceOp)) { - generateEmptyTensorAndWrite(sourceOp, srcOpCanoniclizedMap, - OpAnchorPos[sourceOp], rtKind, - visitedOperation); + // get write operation + if (auto writeOp = getNextTargetOperationInCurrentGroup< + vector::TransferWriteOp>(sourceOp, sourceOpGid)) { + auto writeOpresult = writeOp->getResults()[0]; + auto writeTensor = writeOp->getOperands()[1]; + srcOpCanoniclizedMap.insert( + {sourceOp, {writeTensor, writeOpresult}}); + groupOpInitArgs[sourceOpGid].insert(writeTensor); + updateReturnResultKind(writeOp, sourceOpGid, rtKind); + } else { + generateEmptyTensorAndWrite(sourceOp, srcOpCanoniclizedMap, + OpAnchorPos[sourceOp], rtKind, + visitedOperation); + } } else { // udpate result return type - updateReturnResultKind(sourceOp, sourceOpGid, rtKind); + // updateReturnResultKind(sourceOp, sourceOpGid, rtKind); + updateReturnResultKind( + srcOpCanoniclizedMap[sourceOp].second.getDefiningOp(), + sourceOpGid, rtKind); } auto opInit = 
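`getNextTargetOperationInCurrentGroup` above scans the group's operation queue for the first operation of a requested type that consumes the given value. A standalone version of that scan (the queue is taken by value so the caller's copy is untouched):

```cpp
// Return the first operation of type T in the queue that uses curOp's result.
#include <queue>
#include "mlir/IR/Operation.h"

using namespace mlir;

template <typename T>
Operation *findNextUserOfType(Operation *curOp, std::queue<Operation *> q) {
  while (!q.empty()) {
    Operation *front = q.front();
    q.pop();
    if (!isa<T>(front))
      continue;
    for (Value operand : front->getOperands())
      if (operand.getDefiningOp() == curOp)
        return front;
  }
  return nullptr;
}
// Usage mirroring the hunk above:
//   findNextUserOfType<vector::TransferWriteOp>(sourceOp, groupQueue);
```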
canonicalizeCurrentOperation( @@ -3472,7 +3755,8 @@ void VectorOperationAnalyzer::analysisGroupOperationResults() { // if (!srcOpCanoniclizedMap.contains(sourceOp)) { // auto [tsr, writeOpresult] = // canonicalizeSourceOperation(sourceOp, visitedOperation); - // srcOpCanoniclizedMap.insert({sourceOp, {tsr, writeOpresult}}); + // srcOpCanoniclizedMap.insert({sourceOp, {tsr, + // writeOpresult}}); // } // auto opInit = canonicalizeCurrentOperation( // op, srcOpCanoniclizedMap[sourceOp].second, idx); @@ -3483,7 +3767,8 @@ void VectorOperationAnalyzer::analysisGroupOperationResults() { } } } - // if (mlir::isa(op) && !movedOperationSet.contains(op)) { + // if (mlir::isa(op) && !movedOperationSet.contains(op)) + // { // auto parentBlock = op->getBlock(); // std::stack opStack; @@ -3521,12 +3806,12 @@ mlir::FailureOr ForLoopGenerator::generateVectorizedForLoop( ValueRange forIterArgs(operands); ArrayRef shapes = vectorType.getShape(); SmallVector inductionVars; + DenseMap> indiceLoopMap; // generate for loop auto forOp = constructNestedForOp( - 0, groupId, rewriter, rewriter.getUnknownLoc(), forIterArgs, vectorType, - shapes, inductionVars, operandIdxMap, originalOperandMap, - operandOriginalMap, nextLoopResults, resultIdxMap, - forResultOrignalResultMap); + 0, groupId, rewriter, rewriter.getUnknownLoc(), forIterArgs, shapes, + inductionVars, operandIdxMap, originalOperandMap, operandOriginalMap, + nextLoopResults, resultIdxMap, forResultOrignalResultMap, indiceLoopMap); auto replaceIfFn = [&](OpOperand &use) { return use.getOwner()->getBlock() == forOp->getBlock(); }; @@ -3591,7 +3876,7 @@ void ForLoopGenerator::generateGroupOpVectorizedIR(const int idx) { rewriter.setInsertionPointAfter(grp.back()); // 1. Rewrite operation as vectorized form // 2. Generate loop - rewriteOperationAsVectorize(rewriter, idx); + // rewriteOperationAsVectorize(rewriter, idx); auto forOp = generateVectorizedForLoop(idx, rewriter, groupType); // special operation do not need to change anything if (failed(forOp)) { diff --git a/lib/gc/Transforms/LowerTileVectorPass.cpp b/lib/gc/Transforms/LowerTileVectorPass.cpp index 0049ed82c..e7fa461d7 100644 --- a/lib/gc/Transforms/LowerTileVectorPass.cpp +++ b/lib/gc/Transforms/LowerTileVectorPass.cpp @@ -512,6 +512,32 @@ struct TensorOpConvertVectorPass : public RewritePattern { } }; +struct EliminateWriteReadOpPass + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::TransferReadOp op, + PatternRewriter &rewriter) const override { + auto sourceOp = op->getOperand(0).getDefiningOp(); + if (isa_and_nonnull(sourceOp)) { + rewriter.replaceOp(op, sourceOp->getOperand(0)); + return success(); + } + return failure(); + } +}; + +void eliminateWriteReadOperation(Operation *op) { + if (!isa_and_nonnull(op)) { + return; + } + auto sourceOp = op->getOperand(0).getDefiningOp(); + if (isa_and_nonnull(sourceOp)) { + IRRewriter rewriter(op); + rewriter.replaceOp(op, sourceOp->getOperand(0)); + } +} + /// Pass that lower to tile vector. 
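`EliminateWriteReadOpPass` above folds a transfer_read whose source tensor comes directly from a transfer_write. A self-contained variant is sketched below; note that the pattern as shown does not compare indices or vector types, so a robust version would have to verify that both ops cover the same slice with identical types before forwarding.

```cpp
// Sketch of the write->read forwarding pattern; the overlap and type checks
// mentioned above are left as a TODO to keep the example short.
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

struct ForwardWriteToRead : OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    auto writeOp = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
    if (!writeOp)
      return failure();
    // TODO: verify matching indices, maps, and identical vector types.
    rewriter.replaceOp(readOp, writeOp.getVector());
    return success();
  }
};
```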
void populateLowerToTileVectorPatterns(RewritePatternSet &patterns) { patterns.add, @@ -535,25 +561,39 @@ struct LowerTileVectorPass }; // some operation convert as constant, this pattern can help us to improve // the performance - tensor::populateRewriteAsConstantPatterns(patterns, defaultControlFn); + // tensor::populateRewriteAsConstantPatterns(patterns, defaultControlFn); // remove unnessary operation - tensor::populateReassociativeReshapeFoldingPatterns(patterns); - tensor::populateFoldTensorSubsetOpPatterns(patterns); - tensor::populateFoldTensorEmptyPatterns(patterns, true); - tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns); + // tensor::populateReassociativeReshapeFoldingPatterns(patterns); + // tensor::populateFoldTensorSubsetOpPatterns(patterns); + // tensor::populateFoldTensorEmptyPatterns(patterns, true); + // tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns); populateLowerToTileVectorPatterns(patterns); linalg::populatePadOpVectorizationPatterns(patterns); - // ensure read and write on last dimension - vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); - // remove unnessary broadcast operation - vector::populateSinkVectorBroadcastPatterns(patterns); - vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx); - vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx); - GreedyRewriteConfig config; config.strictMode = GreedyRewriteStrictness::ExistingOps; (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns), config); + // error case: + // due to insert slice tensor<1x32xf32> to tensor<1x128x1x32xf32> + // linalg.copy : <1x32xf32> + // -> transfer_write : permutation map = (d0, d1, d2, d3) -> (d0, d3) + // Inorder to avoid the fold greedily bug (fold wrong permution map for the + // transfer_write operation). Give it the new full IR to fold second time + // can fold correctly. 
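The two-stage application described in the comment above amounts to: run the main lowering patterns restricted to the ops that already existed, then run the permutation-map lowering on the whole function so it sees the final IR. A condensed sketch (the wrapper function name is illustrative):

```cpp
// Two-stage greedy rewrite, as described in the comment above.
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

static void applyLoweringInTwoStages(func::FuncOp funcOp,
                                     RewritePatternSet &&mainPatterns) {
  GreedyRewriteConfig config;
  config.strictMode = GreedyRewriteStrictness::ExistingOps;
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(mainPatterns), config);

  // Second stage: permutation-map lowering runs on the rewritten IR.
  RewritePatternSet second(funcOp.getContext());
  vector::populateVectorTransferPermutationMapLoweringPatterns(second);
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(second));
}
```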
+ RewritePatternSet secondPattern(ctx); + // secondPattern.add(patterns.getContext()); + // ensure read and write on last dimension + vector::populateVectorTransferPermutationMapLoweringPatterns(secondPattern); + // remove unnessary broadcast operation + // vector::populateSinkVectorBroadcastPatterns(secondPattern); + // vector::TransferReadOp::getCanonicalizationPatterns(secondPattern, ctx); + // vector::TransferWriteOp::getCanonicalizationPatterns(secondPattern, ctx); + // tensor::populateFoldTensorSubsetIntoVectorTransferPatterns(patterns); + + (void)applyPatternsAndFoldGreedily(funcOp, std::move(secondPattern)); + DominanceInfo domInfo; + IRRewriter rewriter(funcOp); + eliminateCommonSubExpressions(rewriter, domInfo, funcOp); } }; } // namespace From 37ea49b5d91d47a8dc42e2ca39b582a55f6bd7b3 Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Tue, 13 Aug 2024 16:03:01 +0800 Subject: [PATCH 23/66] fix useless vector operation --- include/gc/Transforms/TilingVector.h | 32 +- lib/gc/Transforms/CPUPhysicalResigterPass.cpp | 509 +++++++++++------- lib/gc/Transforms/LowerTileVectorPass.cpp | 14 +- lib/gc/Transforms/Pipeline.cpp | 4 +- 4 files changed, 359 insertions(+), 200 deletions(-) diff --git a/include/gc/Transforms/TilingVector.h b/include/gc/Transforms/TilingVector.h index ca28afbe8..3ccd7ed7e 100644 --- a/include/gc/Transforms/TilingVector.h +++ b/include/gc/Transforms/TilingVector.h @@ -8,6 +8,7 @@ #ifndef GC_PASSES_TILINGVECTOR_H #define GC_PASSES_TILINGVECTOR_H +#include "gc/Dialect/Linalgx/LinalgxOps.h" #include "gc/Transforms/Passes.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -42,7 +43,6 @@ #include "llvm/Support/Casting.h" #include "llvm/Support/ErrorHandling.h" #include -#include #include #include #include @@ -124,18 +124,22 @@ class VectorFusionStrategy : public TypeHelper { VectorFusionStrategy &operator=(VectorFusionStrategy &&) = default; + /// Get the map which contains each group vector type which has biggest rank. llvm::SmallDenseMap &getGroupBiggestRankVectorType() { return groupBigestRankVectorType; }; + /// Get the operation group obtained by fusion strategy analysis SmallVector, 8> &getOpGroups() { return opGroups; } + /// Get the operation belong to which group index map DenseMap &getOpGroupIndexMap() { return opGroupIndexMap; } + /// Get the map contains max steps of each group llvm::SmallVector &getGroupMaxSteps() { return groupMaxSteps; } llvm::DenseMap &getOpAnchorPos() { return opAnchorPos; } func::FuncOp &getFunc() { return func; } - + /// Do fusion strategy void classifyOperations(); /// Whether two operations have compatible vector shapes @@ -154,6 +158,13 @@ class VectorFusionStrategy : public TypeHelper { void run(); }; +/// Has two kind: +/// 1. OperationGroup: +/// The operation is converted into physical registers through our fusion +/// strategy. +/// 2. Operations:(TODO:) +/// The user ensures that there is no data dependency between operations, +/// and we directly convert the operations into physical register sizes. 
enum CanonicalizerKind { OperationsGroup, Operations }; template class SpecialOperationCanonicalizer { @@ -510,7 +521,10 @@ class ForLoopGenerator : virtual public CanonicalizerCommonUsedData { llvm::DenseMap &nextAnchorResultsIdxMap, const ValueRange &forResults, const std::queue &movedOperaiton, - DenseMap &forResultOrignalResultMap); + DenseMap &forResultOrignalResultMap, + ValueRange loopState, + DenseMap ¤tOperandOriginMap, + DenseMap &nextOperandIdxMap); /// todo: need to add a struct to remove so many parameters void movePostOpToCurrentAnchor( @@ -609,11 +623,16 @@ class ForLoopGenerator : virtual public CanonicalizerCommonUsedData { VectorType loopType, ArrayRef inductionVars, SmallVectorImpl &readVars); + + /// rectify each group operand use for loop result + void rectifyGroupOperands(size_t currentGroupId, Value originalResult, + Value forResult); }; class VectorOperationAnalyzer : virtual public CanonicalizerCommonUsedData { private: func::FuncOp func; + DenseMap> srcOpCanoniclizedMap; public: virtual ~VectorOperationAnalyzer(){}; @@ -625,9 +644,14 @@ class VectorOperationAnalyzer : virtual public CanonicalizerCommonUsedData { // operation void analysisEmptyGroup(); void analysisGroupMaxSteps(); + /// analysis operation result of current group whether needed by other + /// operation void analysisGroupOperaion(); - void analysisGroupOperationResults(); + void specialOperationRectify(DenseMap &visitedOperation); + /// + void updateReturnResultKind(Operation *sourceOp, size_t sourceOpGid, + ReturnTypeKind rtKind); }; /// Vectorize vector operation with target machines simd instructions. class CanonicalizerVectorOperation : virtual public ForLoopGenerator, diff --git a/lib/gc/Transforms/CPUPhysicalResigterPass.cpp b/lib/gc/Transforms/CPUPhysicalResigterPass.cpp index f905e69f4..9d4e6e657 100644 --- a/lib/gc/Transforms/CPUPhysicalResigterPass.cpp +++ b/lib/gc/Transforms/CPUPhysicalResigterPass.cpp @@ -23,43 +23,48 @@ namespace { arith::FPToSIOp, arith::FPToUIOp, arith::SIToFPOp, arith::UIToFPOp, \ arith::TruncFOp, arith::TruncIOp +#define IMPLEMENTED_MATMUL \ + linalgx::BatchReduceMatmulVnniOp, linalgx::MultiBatchMatmulOp, \ + linalg::BatchReduceMatmulOp, linalgx::Mm2DVnniOp, linalgx::Mm4DVnniOp, \ + linalg::MatmulOp, linalg::BatchMatmulOp, \ + linalg::BatchMatmulTransposeAOp, linalg::BatchMatmulTransposeBOp, \ + linalg::MatmulTransposeAOp, linalg::MatmulTransposeBOp, \ + linalg::QuantizedBatchMatmulOp, linalg::QuantizedMatmulOp + // TODO: remove it in the future bool disableSpecialOp = false; -bool disableBroadcastOp = false; +bool disableBroadcastOp = true; +bool enableDebugPrinter = false; + +void printQueue(const std::queue &opQueue) { + auto tempQ(opQueue); + while (!tempQ.empty()) { + auto cur = tempQ.front(); + cur->dump(); + tempQ.pop(); + } +} void printGroupOps(SmallVector, 8> &opGroups) { for (auto [idx, grp] : llvm::enumerate(opGroups)) { - llvm::outs() << idx << " group id." 
- << "\n"; + llvm::outs() << " group id: " << idx << "\n"; if (grp.empty()) { continue; } llvm::outs() << "__________________ group start_____________" << "\n"; - std::queue tmpQ(grp); - while (!tmpQ.empty()) { - auto cur = tmpQ.front(); - tmpQ.pop(); - cur->dump(); - } + printQueue(grp); llvm::outs() << "__________________ group end_____________" << "\n"; } } -void printQueue(const std::queue &opQueue) { - std::cout << "________________________________ op Queue " - "__________________" - << std::endl; - auto tempQ(opQueue); - while (!tempQ.empty()) { - auto cur = tempQ.front(); - cur->dump(); - tempQ.pop(); - } - std::cout << "________________________________ op queue end " - "__________________" - << std::endl; +static inline bool isSpecialLinalgOp(Operation *op) { + return isa(op); +} + +static inline bool isReadOrWriteOperation(Operation *op) { + return isa(op); } /// whether op2 use op1 result @@ -69,9 +74,6 @@ template , T>> static bool isOperationsHasDefUseRelation(Operation *op1, Operation *op2) { - if (!isa(op1) and !isa(op2)) { - return false; - } for (Value opd : op2->getOperands()) { if (opd.getDefiningOp() == op1) { return true; @@ -90,12 +92,13 @@ static size_t getFirstTrueIndex(ArrayRef ararys) { return -1; } -bool isSpecialOp(Operation *op) { - return isa(op) || isa(op) || - isa(op) || isa(op) || - isa(op) || isa(op); +static inline bool isSpecialOp(Operation *op) { + return isa( + op); } +/// operation should not contain for loop bool is_innermost_operation(Operation *op) { bool inner_most = true; op->walk([&inner_most](Operation *p) { @@ -108,20 +111,46 @@ bool is_innermost_operation(Operation *op) { return inner_most; } +/// whether operation is a not support operation bool isNotSupportOperation(Operation *op) { - return isa(op) || isa(op) || - isa(op) || isa(op) || - isa(op); + return isa(op); } -bool isReadOrWriteOperation(Operation *op) { - return isa(op) || isa(op); +/// whether operation is operate on dynamic shape +bool hasDynamicShape(Operation *op) { + auto isDynamicShapedType = [](Value x) { + if (auto type = dyn_cast(x.getType())) { + if (ShapedType::isDynamicShape(type.getShape())) { + return true; + } + } + return false; + }; + // Check operands data type. + for (auto x : op->getOperands()) { + if (isDynamicShapedType(x)) { + return true; + } + } + // Check results data type. + for (auto x : op->getResults()) { + if (isDynamicShapedType(x)) { + return true; + } + } + return false; } // TODO: Need to support these operations in the future bool hasNotSupportOperation(func::FuncOp *func) { auto walkRes = func->walk([](Operation *op) { if (isNotSupportOperation(op)) { + LDBG("Operation do not support yet : " << *op << "\n"); + return WalkResult::interrupt(); + } + if (hasDynamicShape(op)) { + LDBG("Operation has dynamic shape: " << *op << "\n"); return WalkResult::interrupt(); } return WalkResult::advance(); @@ -129,7 +158,7 @@ bool hasNotSupportOperation(func::FuncOp *func) { return walkRes != WalkResult::advance(); } -// select nearest even step +/// select nearest even step int getNearestVectorStep(const int step) { assert(step > 0); int nbits = 0, n = step; @@ -141,32 +170,16 @@ int getNearestVectorStep(const int step) { return (1 << (nbits - 1)) == step ? step : (1 << nbits); } -int TypeHelper::generateValidSteps(int steps, VectorType type) { - return type.getShape().back() >= steps - ? (steps > 16 ? 
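`getNearestVectorStep` above rounds a positive step count up to the nearest power of two (an exact power of two is returned unchanged). A standalone copy with example values:

```cpp
// Standalone copy of the nearest power-of-two rounding used above.
#include <cassert>

int nearestVectorStep(int step) {
  assert(step > 0);
  int nbits = 0, n = step;
  while (n) {
    n >>= 1;
    ++nbits;
  }
  return (1 << (nbits - 1)) == step ? step : (1 << nbits);
}
// nearestVectorStep(3) == 4, nearestVectorStep(8) == 8,
// nearestVectorStep(12) == 16
```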
16 : steps) - : getNearestVectorStep(type.getShape().back()); -} - -// expr equals `vector rank` - 1 +/// whether operate on last dimension bool isLastDim(const AffineExpr &expr, const size_t rank) { - return mlir::isa(expr) && - mlir::dyn_cast(expr).getPosition() == rank - 1; + return isa(expr) && + dyn_cast(expr).getPosition() == rank - 1; } -[[nodiscard]] int TypeHelper::getDataTypeValidSteps(VectorType type) { - auto typebits = type.getElementTypeBitWidth(); - const int favx512bits = 512; - const int favx2bits = 256; - if (HWInfo.favx512f) { - return generateValidSteps(favx512bits / typebits, type); - } else if (HWInfo.favx2) { - return generateValidSteps(favx2bits / typebits, type); - } else { - // invalid - LDBG("Please check the hardware information."); - assert(false && "Invalid hardware."); - return -1; - } +int TypeHelper::generateValidSteps(int steps, VectorType type) { + return type.getShape().back() >= steps + ? steps + : getNearestVectorStep(type.getShape().back()); } // Get the maximum number of current data types that a register can hold @@ -179,13 +192,18 @@ bool isLastDim(const AffineExpr &expr, const size_t rank) { } else if (HWInfo.favx2) { return favx2bits / typebits; } else { - // invalid + // invalid hardware LDBG("Please check the hardware information."); assert(false && "Invalid hardware."); return -1; } } +/// Get a appropriate for loop step for current vector type +[[nodiscard]] int TypeHelper::getDataTypeValidSteps(VectorType type) { + return generateValidSteps(getDataTypeMAXSIMDLength(type), type); +} + /// get float or integer dense attribute /// \param [in,out] attr template @@ -196,27 +214,25 @@ void getConstantDenseAttr(TypedAttr &attr, VectorType type, attr = T::get(type, denseAttr.getSplatValue()); } +/// Create a new arith constant operation according to the dense element attr FailureOr createArithSplatConstantOp(IRRewriter &rewriter, const Location &loc, DenseElementsAttr valueType, VectorType newOperandType) { - - if (valueType.isSplat()) { - TypedAttr attr; - if (isa(newOperandType.getElementType())) { - getConstantDenseAttr(attr, newOperandType, - valueType); - } else { - getConstantDenseAttr(attr, newOperandType, - valueType); - } - return rewriter.create(loc, attr)->getResults()[0]; + if (not valueType.isSplat()) { + return failure(); } - return failure(); + TypedAttr attr; + if (isa(newOperandType.getElementType())) { + getConstantDenseAttr(attr, newOperandType, valueType); + } else { + getConstantDenseAttr(attr, newOperandType, valueType); + } + return rewriter.create(loc, attr)->getResults()[0]; } -/// get operation vector type +/// Get vector type of the operation \param op /// \param isPrevOp whether the operation is a previous operation, if it is not /// prev-op, may need to use result vectortype /// default will return the opeation result type @@ -1053,14 +1069,19 @@ void ForLoopGenerator::getResultInCurrentOps( /// \param nextAnchorArgsIdxMap anchor args index map /// \param nextOriginalOperandMap original value to next loop args map /// \param nextOperandOriginalMap next loop args to original value map -void updateCurrentArgsStatus(const ValueRange &loopState, - const size_t loopStateIdx, +void updateCurrentArgsStatus(ValueRange loopState, const size_t loopStateIdx, SmallVector &nextAnchorArgs, Value originalValue, DenseMap &nextAnchorArgsIdxMap, DenseMap &nextOriginalOperandMap, DenseMap &nextOperandOriginalMap) { Value currentArgs = loopState[loopStateIdx]; + if (currentArgs.getType() != originalValue.getType()) { + llvm::outs() << 
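Putting `generateValidSteps` and `getDataTypeMAXSIMDLength` together, the step selection reads: register width (512 bits with AVX-512, 256 with AVX2) divided by the element width gives the lane count, which falls back to the nearest power of two of the innermost dimension when that dimension is smaller. A worked standalone version, with the rounding helper inlined:

```cpp
// Step selection: SIMD lanes from register width, with the innermost vector
// dimension as the fallback when it is the limiter.
#include <cstdint>

static int nearestPow2(int x) {
  int p = 1;
  while (p < x)
    p <<= 1;
  return p;
}

int validSteps(bool hasAvx512, int64_t elementBits, int64_t innermostDim) {
  int64_t registerBits = hasAvx512 ? 512 : 256;
  int64_t lanes = registerBits / elementBits; // e.g. 512 / 32 = 16 for f32
  if (innermostDim >= lanes)
    return static_cast<int>(lanes);
  return nearestPow2(static_cast<int>(innermostDim));
}
// validSteps(true, 32, 64) == 16, validSteps(true, 32, 8) == 8
```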
loopStateIdx << "," + << "\n"; + currentArgs.dump(); + llvm::llvm_unreachable_internal("Type not equal."); + } nextAnchorArgs.emplace_back(currentArgs); nextAnchorArgsIdxMap[currentArgs] = nextAnchorArgs.size() - 1; nextOriginalOperandMap[originalValue] = currentArgs; @@ -1082,7 +1103,7 @@ void ForLoopGenerator::getInitArgsToNextAnchor( DenseSet visited; // find the next anchor arguments std::queue tmpQ(nextOperations); - DenseMap nextOriginalArgsMap, nextOperandOriginalMap; + DenseMap nextOriginalOperandMap, nextOperandOriginalMap; while (!tmpQ.empty()) { Operation *cur = tmpQ.front(); @@ -1091,17 +1112,18 @@ void ForLoopGenerator::getInitArgsToNextAnchor( for (auto x : curOperands) { if (!visited.contains(x) and opInitArgs.contains(x) and opAnchorPos[cur] > anchorIdx) { + assert(originalOperandLoopArgsMap.contains(x)); int loopStateIdx = currentLoopStateIdxMap[originalOperandLoopArgsMap[x]]; updateCurrentArgsStatus(loopState, loopStateIdx, nextAnchorArgs, x, - nextAnchorArgsIdxMap, nextOriginalArgsMap, + nextAnchorArgsIdxMap, nextOriginalOperandMap, nextOperandOriginalMap); visited.insert(x); } } } - originalOperandLoopArgsMap = nextOriginalArgsMap; - loopArgsOriginalOperandMap = nextOperandOriginalMap; + originalOperandLoopArgsMap = std::move(nextOriginalOperandMap); + loopArgsOriginalOperandMap = std::move(nextOperandOriginalMap); } void ForLoopGenerator::getOperationInCurrentAnchor( @@ -1158,7 +1180,7 @@ void ForLoopGenerator::movePreOpToCurrentAnchor( std::queue &candidateQueue, std::queue &movedQueue, DenseMap &originalOperandLoopArgsMap, - DenseMap &LoopArgsoriginalOperandMap, + DenseMap &LoopArgsOriginalOperandMap, DenseMap> &indiceLoopMap) { // 1. get operations in current anchor position @@ -1177,7 +1199,7 @@ void ForLoopGenerator::movePreOpToCurrentAnchor( getInitArgsToNextAnchor(anchorIdx, groupIdx, candidateQueue, loopState, currentLoopStateIdxMap, nextAnchorArgsIdxMap, nextAnchorArgs, originalOperandLoopArgsMap, - LoopArgsoriginalOperandMap); + LoopArgsOriginalOperandMap); // 5. move operations to moved queue while (!movingOperation.empty()) { @@ -1226,7 +1248,9 @@ void ForLoopGenerator::generateLoopResults( const size_t groupIdx, SmallVector &nextAnchorResults, DenseMap &nextAnchorResultsIdxMap, const ValueRange &forResults, const std::queue &movedOperation, - DenseMap &forResultOrignalResultMap) { + DenseMap &forResultOrignalResultMap, ValueRange loopState, + DenseMap ¤tOperandOriginMap, + DenseMap &nextOperandIdxMap) { SmallVector results; DenseMap currentResultMap; getResultInCurrentOps(anchorIdx, groupIdx, movedOperation, results, @@ -1250,10 +1274,15 @@ void ForLoopGenerator::generateLoopResults( nextAnchorResults.clear(); nextAnchorResultsIdxMap.clear(); - for (Value &result : results) { - nextAnchorResults.emplace_back(result); - nextAnchorResultsIdxMap[result] = nextAnchorResults.size() - 1; + // reduction operation due to special process results size will be zero + if (results.size() > 0) { + for (Value x : loopState) { + nextAnchorResults.emplace_back(results[nextOperandIdxMap[x]]); + nextAnchorResultsIdxMap[results[nextOperandIdxMap[x]]] = + nextAnchorResults.size() - 1; + } } + forResultOrignalResultMap = std::move(currentResultMap); } @@ -1343,7 +1372,9 @@ scf::ForOp ForLoopGenerator::reductionAxisGenerateForLoop( // 4. 
generate loop results generateLoopResults(b, loc, anchorIdx, groupIdx, nextAnchorResults, nextAnchorResultsIdxMap, nxtFor->getResults(), - movedOperation, forResultOrignalResultMap); + movedOperation, forResultOrignalResultMap, + loopState, currentArgsOriginalMap, + nextAnchorArgsIdxMap); // reduction must return accumulate if (originalResultForResultMap.contains( @@ -1395,7 +1426,7 @@ scf::ForOp ForLoopGenerator::reductionAxisGenerateForLoop( int accValIdx = currentLoopStateIdxMap [originalOperandLoopArgsMap[multireductionOp.getAcc()]]; // check acc val is the first args - assert(accValIdx == 0); + // assert(accValIdx == 0); Value reductionResult = makeArithReduction( b, loc, multireductionOp.getKind(), multireductionOp.getSource(), @@ -1580,7 +1611,8 @@ scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( // 4. generate loop results generateLoopResults(b, loc, parallelIdx, groupIdx, nextAnchorResults, nextAnchorResultsIdxMap, nxtFor->getResults(), - movedQueue, forResultOrignalResultMap); + movedQueue, forResultOrignalResultMap, loopState, + currentOperandOriginalMap, nextAnchorArgsIdxMap); maybeYieldValue(b, loc, nextAnchorResults); } else if (parallelIdx == parallelAxis.size()) { @@ -1669,6 +1701,11 @@ scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( nextAnchorResults = std::move(currentAnchorResults); forResultOrignalResultMap = std::move(currentResultMap); nextAnchorResultsIdxMap = std::move(currentResultIdxMap); + // std::cout << "next anchor results : " << nextAnchorResults.size() + // << std::endl; + // for (auto x : nextAnchorResults) { + // x.getDefiningOp()->dump(); + // } maybeYieldValue(b, loc, nextAnchorResults); } }); @@ -1833,13 +1870,22 @@ ForLoopGenerator::generateMultiReductionForLoop(const size_t grpIdx) { nextAnchorResultsIdxMap, inductionVars, originalOperandLoopArgsMap, loopArgsOriginalOperandMap, forResultOrignalResultMap, indiceLoopMap); auto replaceIfFn = [&](OpOperand &use) { - return use.getOwner()->getBlock() == forOp->getBlock(); + auto walkResult = forOp->walk([&](Operation *op) { + if (use.getOwner() == op) { + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return walkResult != WalkResult::interrupt(); }; for (auto x : nextAnchorResults) { auto originalResult = forResultOrignalResultMap[x]; rewriter.replaceOpUsesWithIf( originalResult.getDefiningOp(), forOp->getResults()[nextAnchorResultsIdxMap[x]], replaceIfFn); + // following group must use the replaced result as operand + rectifyGroupOperands(grpIdx, originalResult, + forOp->getResults()[nextAnchorResultsIdxMap[x]]); } rewriter.eraseOp(multiReductionOp); @@ -1925,6 +1971,8 @@ scf::ForOp ForLoopGenerator::generateShapeCastReadWriteLoop( VectorType loopType = sourceType.getRank() > destType.getRank() ? 
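The innermost reduction body above combines the freshly loaded source vector with the loop-carried accumulator via `makeArithReduction`. A minimal wrapper showing that call is sketched below; it assumes the helper is `mlir::vector::makeArithReduction` from VectorOps.h, which materializes the elementwise combine matching the multi_reduction's combining kind.

```cpp
// Minimal wrapper around the per-iteration accumulation used above.
#include "mlir/Dialect/Vector/IR/VectorOps.h"

using namespace mlir;

static Value accumulate(OpBuilder &b, Location loc, vector::CombiningKind kind,
                        Value sourceVec, Value acc) {
  // Emits the add/mul/min/max/... that corresponds to `kind`.
  return vector::makeArithReduction(b, loc, kind, sourceVec, acc);
}
```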
sourceType : destType; size_t rank = loopType.getRank(); + DenseMap &opIndexMap = + getFusionStrategy().getOpGroupIndexMap(); auto zero = makeIndexArithConstantOp(opBuilder, loc, 0); bool isLastDim = loopType.getRank() - 1 == (int64_t)forDimIdx; @@ -1947,11 +1995,11 @@ scf::ForOp ForLoopGenerator::generateShapeCastReadWriteLoop( Value source = scOp->getOperand(0); auto readSourceOp = cast(source.getDefiningOp()); - vector::TransferWriteOp successorWriteOp; + SmallVector successorWriteOps; for (Operation *x : scOp->getUsers()) { - if (isa(x)) { - successorWriteOp = cast(x); - break; + if (isa(x) and opIndexMap.contains(x) and + opIndexMap[x] == opIndexMap[scOp]) { + successorWriteOps.emplace_back(cast(x)); } } SmallVector exprs(loopType.getRank(), AffineExpr()); @@ -2034,12 +2082,15 @@ scf::ForOp ForLoopGenerator::generateShapeCastReadWriteLoop( SmallVector writeVars = loopType == sourceType ? smallRankShapeVars : inductionVars; - - rectifyWriteOperationIndice(&successorWriteOp, writeVars); - auto writeOp = b.create( - loc, transferReadOp->getResults()[0], loopState[0], writeVars, - inBoundsVal); - maybeYieldValue(b, loc, writeOp->getResults()); + SmallVector writeResults; + for (auto successorWriteOp : successorWriteOps) { + rectifyWriteOperationIndice(&successorWriteOp, writeVars); + auto writeOp = b.create( + loc, transferReadOp->getResults()[0], loopState[0], writeVars, + inBoundsVal); + writeResults.emplace_back(writeOp->getResults()[0]); + } + maybeYieldValue(b, loc, writeResults); } else { // outter loop auto nxtFor = generateShapeCastReadWriteLoop( @@ -2060,7 +2111,7 @@ void ForLoopGenerator::rectifyWriteOperationIndice( Operation::operand_range writeIndices = originalWriteOp->getIndices(); for (size_t i = 0; i < inMutableIdx; i++) { - writeVars[i] = (writeIndices[i]); + writeVars[i] = writeIndices[i]; } } @@ -2096,17 +2147,21 @@ scf::ForOp ForLoopGenerator::generateShapeCastForLoop(const size_t grpIdx) { VectorType sourceType = scOp.getSourceVectorType(); VectorType destType = scOp.getResultVectorType(); + DenseMap &opIndexMap = + getFusionStrategy().getOpGroupIndexMap(); OpBuilder b(scOp); SmallVector iterArgs; - vector::TransferWriteOp successorWriteOp; + SmallVector successorWriteOps; for (Operation *x : scOp->getUsers()) { - if (isa(x)) { - successorWriteOp = cast(x); - break; + if (isa(x) and opIndexMap.contains(x) and + opIndexMap[x] == opIndexMap[x]) { + successorWriteOps.emplace_back(cast(x)); } } - iterArgs.emplace_back(successorWriteOp->getOperands()[1]); + for (auto successorWriteOp : successorWriteOps) { + iterArgs.emplace_back(successorWriteOp->getOperands()[1]); + } SmallVector inductionVars; IRRewriter rewriter(func); const size_t groupStep = getFusionStrategy().getGroupMaxSteps()[grpIdx]; @@ -2120,7 +2175,10 @@ scf::ForOp ForLoopGenerator::generateShapeCastForLoop(const size_t grpIdx) { scCanonicalizer.isReadWriteOnLastDim()) { scf::ForOp forOp = generateShapeCastReadWriteLoop( b, grpIdx, 0, groupStep, scOp.getLoc(), inductionVars, iterArgs); - rewriter.replaceOp(successorWriteOp, forOp); + for (auto [idx, successorWriteOp] : enumerate(successorWriteOps)) { + rewriter.replaceOp(successorWriteOp, forOp->getResults()[idx]); + } + rewriter.eraseOp(scOp); clearCurrentOperationGroup(grpIdx); return forOp; } @@ -2128,7 +2186,10 @@ scf::ForOp ForLoopGenerator::generateShapeCastForLoop(const size_t grpIdx) { // scalar data movement scf::ForOp forOp = generateShapeCastReadWriteLoop( b, grpIdx, 0, 1, scOp.getLoc(), inductionVars, iterArgs); - 
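The successor-write collection above only accepts users that are transfer_write ops belonging to the same fusion group as the shape_cast. A standalone sketch of that filter, with group membership passed in as a plain map:

```cpp
// Collect the transfer_write users of `op` that live in the same fusion group.
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

static SmallVector<vector::TransferWriteOp>
collectGroupWrites(Operation *op,
                   const DenseMap<Operation *, size_t> &groupOf) {
  SmallVector<vector::TransferWriteOp> writes;
  auto self = groupOf.find(op);
  if (self == groupOf.end())
    return writes;
  for (Operation *user : op->getUsers()) {
    auto write = dyn_cast<vector::TransferWriteOp>(user);
    if (!write)
      continue;
    auto it = groupOf.find(user);
    if (it != groupOf.end() && it->second == self->second)
      writes.push_back(write);
  }
  return writes;
}
```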
rewriter.replaceOp(successorWriteOp, forOp); + for (auto [idx, successorWriteOp] : enumerate(successorWriteOps)) { + rewriter.replaceOp(successorWriteOp, forOp->getResults()[idx]); + } + rewriter.eraseOp(scOp); clearCurrentOperationGroup(grpIdx); return forOp; } @@ -2223,13 +2284,21 @@ scf::ForOp ForLoopGenerator::generateTransposeForLoop(const size_t grpIdx) { resultIdxMap, forResultOrignalResultMap, indiceLoopMap); auto replaceIfFn = [&](OpOperand &use) { - return use.getOwner()->getBlock() == forOp->getBlock(); + auto walkResult = forOp->walk([&](Operation *op) { + if (use.getOwner() == op) { + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return walkResult != WalkResult::interrupt(); }; for (auto x : nextLoopResults) { auto originalResult = forResultOrignalResultMap[x]; rewriter.replaceOpUsesWithIf(originalResult.getDefiningOp(), forOp->getResults()[resultIdxMap[x]], replaceIfFn); + rectifyGroupOperands(grpIdx, originalResult, + forOp->getResults()[resultIdxMap[x]]); } // clear current group operation clearCurrentOperationGroup(grpIdx); @@ -2480,6 +2549,8 @@ void CanonicalizerVectorOperation::initSpeicalOperationCanonicalizers() { } else if (isa(op)) { getShapeCastCanonicalizers().back().getCandidateOps().emplace_back( cast(op)); + llvm::outs() << " current shape cast op: "; + op->dump(); } } } @@ -2505,15 +2576,20 @@ void CanonicalizerVectorOperation::canonicalizeSpecialOperation() { // generate MultiReduction for loops (void)generateMultiReductionForLoop(groupId); } + } + for (auto [groupId, tpCanonicalizer] : + llvm::enumerate(transposeCanonicalizers)) { SmallVector &transposeOps = - transposeCanonicalizers[groupId].getCandidateOps(); + tpCanonicalizer.getCandidateOps(); if (!transposeOps.empty()) { (void)generateTransposeForLoop(groupId); } - - SmallVector &shapeCastOps = - shapeCastCanonicalizers[groupId].getCandidateOps(); - if (!shapeCastOps.empty()) { + } + for (auto [groupId, scCanonicalizer] : + llvm::enumerate(shapeCastCanonicalizers)) { + SmallVector &scOps = + scCanonicalizer.getCandidateOps(); + if (!scOps.empty()) { (void)generateShapeCastForLoop(groupId); } } @@ -2559,13 +2635,17 @@ void CanonicalizerVectorOperation::run() { // Query groupResultYeildSet to map operaion result value to scf.yield // result value. 
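The `replaceIfFn` above was changed from a block comparison to a walk of the new loop: a use is rewritten only when its owner is not nested anywhere inside the generated scf.for. A standalone helper with the same logic (it assumes the original op has a single result):

```cpp
// Replace uses of `original` with `loopResult` only outside the new loop.
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

static void replaceUsesOutsideLoop(RewriterBase &rewriter, Operation *original,
                                   scf::ForOp forOp, Value loopResult) {
  auto isOutsideLoop = [&](OpOperand &use) {
    WalkResult found = forOp->walk([&](Operation *op) {
      return op == use.getOwner() ? WalkResult::interrupt()
                                  : WalkResult::advance();
    });
    return found != WalkResult::interrupt(); // not found inside => outside
  };
  rewriter.replaceOpUsesWithIf(original, loopResult, isOutsideLoop);
}
```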
- printGroupOps(getFusionStrategy().getOpGroups()); - std::cout << "___________ before analysis ________________" - << "\n"; + if (enableDebugPrinter) { + printGroupOps(getFusionStrategy().getOpGroups()); + llvm::outs() << "___________ before analysis ________________" + << "\n"; + } analysisGroupOperaion(); - std::cout << "___________ after analysis ________________" - << "\n"; - printGroupOps(getFusionStrategy().getOpGroups()); + if (enableDebugPrinter) { + llvm::outs() << "___________ after analysis ________________" + << "\n"; + printGroupOps(getFusionStrategy().getOpGroups()); + } // Speical Operation Canonicalization canonicalizeSpecialOperation(); @@ -2738,15 +2818,16 @@ scf::ForOp ForLoopGenerator::constructNestedForOp( std::queue movedQueue; std::queue &opQueue = getFusionStrategy().getOpGroups()[groupIdx]; + movePreOpToCurrentAnchor( forDimIdx, groupIdx, b, inductionVars, loopState, operandIdxMap, nextAnchorArgsIdxMap, nextAnchorArgs, opQueue, movedQueue, originalOperandMap, operandOriginalMap, indiceLoopMap); auto nxtFor = constructNestedForOp( - forDimIdx + 1, groupIdx, b, loc, loopState, dims, inductionVars, - nextAnchorArgsIdxMap, originalOperandMap, operandOriginalMap, - nextAnchorResults, nextAnchorResultsIdxMap, + forDimIdx + 1, groupIdx, b, loc, nextAnchorArgs, dims, + inductionVars, nextAnchorArgsIdxMap, originalOperandMap, + operandOriginalMap, nextAnchorResults, nextAnchorResultsIdxMap, forResultOrignalResultMap, indiceLoopMap); movePostOpToCurrentAnchor( @@ -2757,7 +2838,8 @@ scf::ForOp ForLoopGenerator::constructNestedForOp( generateLoopResults(b, loc, forDimIdx, groupIdx, nextAnchorResults, nextAnchorResultsIdxMap, nxtFor->getResults(), - movedQueue, forResultOrignalResultMap); + movedQueue, forResultOrignalResultMap, loopState, + currentOperandOriginalMap, nextAnchorArgsIdxMap); maybeYieldValue(b, loc, nextAnchorResults); } @@ -2807,17 +2889,6 @@ bool VectorFusionStrategy::isCompatibleVectorType(Operation *op1, auto sp1 = type1.value(); auto sp2 = type2.value(); - // if (isReadOrWriteOperation(op1) or isReadOrWriteOperation(op2)) { - // if (sp1.getRank() != sp2.getRank()) { - // return false; - // } - // for (long i = 0; i < sp1.getRank(); i++) { - // if (sp1.getDimSize(i) != sp2.getDimSize(i)) { - // return false; - // } - // } - // } - auto isCompatible = [](VectorType sp1, VectorType sp2) { bool isCompatible = true; auto min_rank = std::min(sp1.getRank(), sp2.getRank()); @@ -2971,6 +3042,40 @@ void getOperationDataAxis(Operation *op, SmallVector &dataAxis) { }); } +bool readWriteDependency(Operation *op1, Operation *op2) { + if (not(isReadOrWriteOperation(op1) and isReadOrWriteOperation(op2))) { + return false; + } + auto readWriteOrder = [](Operation *op1, Operation *op2) { + if (isa(op1) and + isa(op2)) { + return true; + } + return false; + }; + if (!readWriteOrder(op1, op2) and !readWriteOrder(op2, op1)) { + return false; + } + + // e.g.: if op1 is read the value and pass it to op2, it is not data + // dependency + if (isOperationsHasDefUseRelation(op1, op2)) { + return false; + } + return true; +} + +static inline bool hasSameAxis(ArrayRef dims1, + ArrayRef dims2) { + DenseSet checkSet(dims2.begin(), dims2.end()); + for (auto x : dims1) { + if (checkSet.contains(x)) { + return true; + } + } + return false; +} + /// whether two operation has data dependency /// op1 default is previous operation, op2 default is current operation bool hasDataDependency(Operation *op1, Operation *op2) { @@ -2984,13 +3089,11 @@ bool hasDataDependency(Operation *op1, Operation 
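`hasSameAxis` above reduces the data-dependency test between special operations to a set intersection over the dimensions each operation touches. A standalone version on plain containers:

```cpp
// Axis-overlap test: two operations conflict when their dimension sets meet.
#include <cstdint>
#include <unordered_set>
#include <vector>

bool hasSameAxis(const std::vector<int64_t> &dims1,
                 const std::vector<int64_t> &dims2) {
  std::unordered_set<int64_t> lookup(dims2.begin(), dims2.end());
  for (int64_t d : dims1)
    if (lookup.count(d))
      return true;
  return false;
}
// hasSameAxis({0, 2}, {1, 2}) == true, hasSameAxis({0}, {1, 3}) == false
```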
*op2) { return true; } - // if op1 is read the value and pass it to op2, it is not data dependency - if (isOperationsHasDefUseRelation(op1, op2)) { - return false; - } - // if op2 is write the result from op2, it is not data dependency - if (isOperationsHasDefUseRelation(op1, op2)) { - return false; + if (isReadOrWriteOperation(op1) or isReadOrWriteOperation(op2)) { + // if op1 is read the value and pass it to op2, it is not data dependency + if (isOperationsHasDefUseRelation(op1, op2)) { + return false; + } } // broadcast only fuse with post-op @@ -3006,23 +3109,16 @@ bool hasDataDependency(Operation *op1, Operation *op2) { return hasDataDependency(op2, op1); } - auto hasSameAxis = [](const SmallVector &dims1, - const SmallVector &dims2) { - DenseSet checkSet(dims2.begin(), dims2.end()); - for (auto x : dims1) { - if (checkSet.contains(x)) { - return true; - } - } - return false; - }; auto res = TypeSwitch(op1) .Case([&](vector::ShapeCastOp shapeCastOp) { SmallVector dims1, dims2; getOperationDataAxis(op1, dims1); getOperationDataAxis(op2, dims2); - return hasSameAxis(dims1, dims2); + if (!isSpecialOp(op2)) { + return hasSameAxis(dims1, dims2); + } + return true; }) .Case( [&](vector::MultiDimReductionOp multiReductionOp) { @@ -3085,7 +3181,6 @@ bool hasDataDependency(Operation *op1, Operation *op2) { getOperationDataAxis(op2, dims2); if (!isSpecialOp(op2)) { return hasSameAxis(dims1, dims2); - } else { } return true; }) @@ -3102,6 +3197,12 @@ bool VectorFusionStrategy::isNeedNewGroup(Operation *op) { if (prevOp->getParentOp() != op->getParentOp()) { return true; } + + // read and write operation dependency + if (readWriteDependency(prevOp, op)) { + return true; + } + // special operation need to check data dependency axis if (hasDataDependency(prevOp, op)) { return true; @@ -3154,6 +3255,11 @@ void VectorFusionStrategy::classifyOperations() { func->walk([&](Operation *op) { if (filterOperation(op)) { addOperationToGroup(op); + } else if (isSpecialLinalgOp(op)) { + // following operation need a new group + if (opGroups.back().size() > 0) { + opGroups.emplace_back(std::queue()); + } } }); } @@ -3332,6 +3438,7 @@ void ForLoopGenerator::rewriteOperationAsVectorize( }) .Default([&](Operation *op) { if (isSpecialOp(op)) { + op->dump(); llvm::llvm_unreachable_internal( "It should not appear this operation."); return failure(); @@ -3595,44 +3702,46 @@ void VectorOperationAnalyzer::specialOperationRectify( } } -/// analysis operation result of current group whether needed by other -/// operation which out of current group -void VectorOperationAnalyzer::analysisGroupOperationResults() { - DenseMap> srcOpCanoniclizedMap; +void VectorOperationAnalyzer::updateReturnResultKind(Operation *sourceOp, + size_t sourceOpGid, + ReturnTypeKind rtKind) { + SmallVector>, 8> + &groupOpResults = getGroupOpResults(); + DenseMap &OpAnchorPos = + getFusionStrategy().getOpAnchorPos(); + + Value sourceResult = sourceOp->getResults()[0]; + if (srcOpCanoniclizedMap.contains(sourceOp)) { + sourceResult = srcOpCanoniclizedMap[sourceOp].second; + } + + size_t srcOpAnchor = groupOpResults[sourceOpGid][sourceResult].second; + ReturnTypeKind prevRtKind = groupOpResults[sourceOpGid][sourceResult].first; + srcOpAnchor = std::min(srcOpAnchor, OpAnchorPos[sourceOp]); + if (prevRtKind != rtKind) { + groupOpResults[sourceOpGid][sourceResult] = + std::make_pair(ReturnTypeKind::RT_Both, srcOpAnchor); + return; + } + if (rtKind == ReturnTypeKind::RT_InGroup) { + groupOpResults[sourceOpGid][sourceResult] = + std::make_pair(rtKind, 
srcOpAnchor); + } +} + +void VectorOperationAnalyzer::analysisGroupOperaion() { // record the operation which has been moved DenseSet movedOperationSet; - // record the operation's position which has visited, inorder to ensure set + // record the operation's visited order, inorder to ensure set // correct operand size_t opCounter = 0; DenseMap visitedOperation; DenseMap &opGroupIndexMap = getFusionStrategy().getOpGroupIndexMap(); SmallVector, 8> &groupOpInitArgs = getGroupOpInitArgs(); - SmallVector>, 8> - &groupOpResults = getGroupOpResults(); DenseMap &OpAnchorPos = getFusionStrategy().getOpAnchorPos(); - auto updateReturnResultKind = [&](Operation *sourceOp, size_t sourceOpGid, - ReturnTypeKind rtKind) { - Value sourceResult; - if (srcOpCanoniclizedMap.contains(sourceOp)) { - sourceResult = srcOpCanoniclizedMap[sourceOp].second; - } else { - sourceResult = sourceOp->getResults()[0]; - } - size_t srcOpAnchor = groupOpResults[sourceOpGid][sourceResult].second; - ReturnTypeKind prevRtKind = groupOpResults[sourceOpGid][sourceResult].first; - srcOpAnchor = std::min(srcOpAnchor, OpAnchorPos[sourceOp]); - if (prevRtKind != rtKind) { - groupOpResults[sourceOpGid][sourceResult] = - std::make_pair(ReturnTypeKind::RT_Both, srcOpAnchor); - } else if (rtKind == ReturnTypeKind::RT_InGroup) { - groupOpResults[sourceOpGid][sourceResult] = - std::make_pair(rtKind, srcOpAnchor); - } - }; - analysisGroupMaxSteps(); func.walk([&](Operation *op) { @@ -3782,9 +3891,27 @@ void VectorOperationAnalyzer::analysisGroupOperationResults() { LDBG("Complete analysis group operation results\n"); } -void VectorOperationAnalyzer::analysisGroupOperaion() { - // Results - analysisGroupOperationResults(); +void ForLoopGenerator::rectifyGroupOperands(size_t currentGroupId, + Value originalResult, + Value forResult) { + size_t totalGroupSize = getFusionStrategy().getOpGroups().size(); + size_t startGroup = currentGroupId; + while (startGroup < totalGroupSize) { + SetVector &operandVector = getGroupOpInitArgs()[startGroup]; + if (operandVector.contains(originalResult)) { + SetVector replacedVector; + + for (auto v : operandVector) { + if (v == originalResult) { + replacedVector.insert(forResult); + } else { + replacedVector.insert(v); + } + } + getGroupOpInitArgs()[startGroup] = replacedVector; + } + startGroup++; + } } mlir::FailureOr ForLoopGenerator::generateVectorizedForLoop( @@ -3796,7 +3923,7 @@ mlir::FailureOr ForLoopGenerator::generateVectorizedForLoop( DenseMap operandIdxMap, resultIdxMap; DenseMap originalOperandMap, operandOriginalMap, forResultOrignalResultMap; - auto &initArgs = getGroupOpInitArgs()[groupId]; + SetVector &initArgs = getGroupOpInitArgs()[groupId]; for (Value x : initArgs) { operands.emplace_back(x); operandIdxMap[x] = operands.size() - 1; @@ -3813,13 +3940,23 @@ mlir::FailureOr ForLoopGenerator::generateVectorizedForLoop( inductionVars, operandIdxMap, originalOperandMap, operandOriginalMap, nextLoopResults, resultIdxMap, forResultOrignalResultMap, indiceLoopMap); auto replaceIfFn = [&](OpOperand &use) { - return use.getOwner()->getBlock() == forOp->getBlock(); + auto walkResult = forOp->walk([&](Operation *op) { + if (use.getOwner() == op) { + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return walkResult != WalkResult::interrupt(); }; + for (auto x : nextLoopResults) { auto originalResult = forResultOrignalResultMap[x]; rewriter.replaceOpUsesWithIf(originalResult.getDefiningOp(), forOp->getResults()[resultIdxMap[x]], replaceIfFn); + // following group must use the 
replaced result as operand + rectifyGroupOperands(groupId, originalResult, + forOp->getResults()[resultIdxMap[x]]); } return forOp; @@ -3869,7 +4006,7 @@ void ForLoopGenerator::generateGroupOpVectorizedIR(const int idx) { if (isGroupHasSpecialOperation(idx)) { return; } - auto &groupOpResults = getGroupOpResults(); + VectorType groupType = getFusionStrategy().getGroupBiggestRankVectorType()[idx]; IRRewriter rewriter(grp.back()); @@ -3882,8 +4019,6 @@ void ForLoopGenerator::generateGroupOpVectorizedIR(const int idx) { if (failed(forOp)) { return; } - // 3 Update loop result uses - updateLoopResultUses(groupOpResults[idx], &forOp.value()); moveLoopInvariantCode(forOp.value()); } diff --git a/lib/gc/Transforms/LowerTileVectorPass.cpp b/lib/gc/Transforms/LowerTileVectorPass.cpp index e7fa461d7..fe404ca49 100644 --- a/lib/gc/Transforms/LowerTileVectorPass.cpp +++ b/lib/gc/Transforms/LowerTileVectorPass.cpp @@ -30,7 +30,6 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" -#include namespace mlir { namespace gc { @@ -395,7 +394,8 @@ LogicalResult convert2TargetOperation(RewriterBase &rewriter, Operation *op) { LDBG("Attempting to vectorize:\n" << *op << "\n"); if (failed(lowerTargetOpPrecondition(op))) { - std::cout << "FAILED TO LOWER TARGET OP\n" << std::endl; + llvm::outs() << "FAILED TO LOWER TARGET OP\n" + << "\n"; LDBG("Vectorization pre-conditions failed\n"); return failure(); } @@ -585,15 +585,15 @@ struct LowerTileVectorPass // ensure read and write on last dimension vector::populateVectorTransferPermutationMapLoweringPatterns(secondPattern); // remove unnessary broadcast operation - // vector::populateSinkVectorBroadcastPatterns(secondPattern); + vector::populateSinkVectorBroadcastPatterns(secondPattern); // vector::TransferReadOp::getCanonicalizationPatterns(secondPattern, ctx); // vector::TransferWriteOp::getCanonicalizationPatterns(secondPattern, ctx); - // tensor::populateFoldTensorSubsetIntoVectorTransferPatterns(patterns); + // tensor::populateFoldTensorSubsetIntoVectorTransferPatterns(secondPattern); (void)applyPatternsAndFoldGreedily(funcOp, std::move(secondPattern)); - DominanceInfo domInfo; - IRRewriter rewriter(funcOp); - eliminateCommonSubExpressions(rewriter, domInfo, funcOp); + // DominanceInfo domInfo; + // IRRewriter rewriter(funcOp); + // eliminateCommonSubExpressions(rewriter, domInfo, funcOp); } }; } // namespace diff --git a/lib/gc/Transforms/Pipeline.cpp b/lib/gc/Transforms/Pipeline.cpp index 9a2805668..27c2589fa 100644 --- a/lib/gc/Transforms/Pipeline.cpp +++ b/lib/gc/Transforms/Pipeline.cpp @@ -55,12 +55,12 @@ void populateTensorPasses(mlir::OpPassManager &pm) { // scf + arith + math + vector + tensor + linalg.brgemm void populateVectorPasses(mlir::OpPassManager &pm) { - pm.addPass(createLowerToTileVector()); + pm.addNestedPass(createLowerToTileVector()); // Do promotion for math / arith ops pm.addNestedPass(math::createMathLegalizeToF32()); // sourceTypeStrs can be extended arith::ArithEmulateUnsupportedFloatsOptions options; - std::array typeStr = {"bf16"}; + std::array typeStr{"bf16"}; options.sourceTypeStrs = typeStr; options.targetTypeStr = "f32"; pm.addNestedPass( From 51619213d2498c8404aa71ad95771fca3ca8075d Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Fri, 16 Aug 2024 08:23:31 +0800 Subject: [PATCH 24/66] update lowr tile vector code --- lib/gc/Transforms/LowerTileVectorPass.cpp | 573 ++++++++++++---------- 1 file changed, 324 insertions(+), 249 deletions(-) diff --git 
a/lib/gc/Transforms/LowerTileVectorPass.cpp b/lib/gc/Transforms/LowerTileVectorPass.cpp index fe404ca49..4976f2ff4 100644 --- a/lib/gc/Transforms/LowerTileVectorPass.cpp +++ b/lib/gc/Transforms/LowerTileVectorPass.cpp @@ -1,5 +1,4 @@ -//===- LowerTileVectorPass.cpp.cpp - OneDNNGraph To Linalg -// Lowering -*- C++ -*-===// +//===-- LowerTileVectorPass.cpp - Lower Op to vector ------------*- C++ -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -8,28 +7,32 @@ //===----------------------------------------------------------------------===// #include "gc/Dialect/Linalgx/LinalgxOps.h" #include "gc/Transforms/Passes.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Math/IR/Math.h" -#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/Transforms/Transforms.h" -#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/CSE.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" +#include namespace mlir { namespace gc { @@ -40,7 +43,8 @@ namespace { #define DEBUG_TYPE "lower-to-tile-vector-pass" #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") -#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") +#define SAFE_EXPAND(X) X +#define LDBG(X) LLVM_DEBUG(DBGS() << SAFE_EXPAND(X) << "\n") #define IMPLEMENTED_MATMUL \ linalgx::BatchReduceMatmulVnniOp, linalgx::MultiBatchMatmulOp, \ @@ -51,118 +55,199 @@ namespace { linalg::QuantizedBatchMatmulOp, linalg::QuantizedMatmulOp #define SUPPORT_TENSOR_OP \ - tensor::ExpandShapeOp, tensor::CollapseShapeOp, tensor::BitcastOp, \ - tensor::ConcatOp - -bool is_innermost_ir(Operation *op) { - bool inner_most = true; - op->walk([&inner_most](Operation *p) { - if (llvm::isa(p)) { - inner_most = false; - return WalkResult::interrupt(); - } - return WalkResult::advance(); - }); - return inner_most; + tensor::ExpandShapeOp, tensor::CollapseShapeOp, tensor::ConcatOp + +template +struct decay_equiv : std::is_same::type, U>::type {}; + +static inline bool isRequiredTensorOp(Operation *operation) { + return isa(operation); } -static bool isMatchedOperationUsage(Operation *op) { +/// matmul operation or fill + matmul operation +static bool isMatchedOperationSet(Operation *op) { if (isa(op)) { return true; } - // operation produce for matmul can't lower + // Operation produce for matmul can't lower. + // Currently only the fill operation need to check this. 
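For illustration only, the kind of pattern the `isMatchedOperationSet` check above is meant to skip looks roughly like the following; the shapes and symbol names here are invented and not taken from the test suite. The fill that initializes a matmul accumulator is intentionally left in linalg form so the brgemm path can consume it.

```mlir
// Hypothetical sketch: fill + matmul is not lowered to vector by this pass.
func.func @skip_fill_matmul(%a: tensor<32x64xf32>, %b: tensor<64x16xf32>) -> tensor<32x16xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %empty = tensor.empty() : tensor<32x16xf32>
  %filled = linalg.fill ins(%cst : f32) outs(%empty : tensor<32x16xf32>) -> tensor<32x16xf32>
  %res = linalg.matmul ins(%a, %b : tensor<32x64xf32>, tensor<64x16xf32>)
                       outs(%filled : tensor<32x16xf32>) -> tensor<32x16xf32>
  return %res : tensor<32x16xf32>
}
```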
if (!isa(op)) { return false; } - for (auto x : op->getUsers()) { - if (isa(x)) { - return true; - } - } + return llvm::any_of(op->getUsers(), + [](Operation *x) { return isa(x); }); +} - return false; +static bool isContainsDynamicSize(ArrayRef sizes) { + return llvm::any_of(sizes, + [](int64_t x) { return x == ShapedType::kDynamic; }); } -/// Need to check if the reassociation are static/constant. -LogicalResult lowerExpandOpPrecondition(tensor::ExpandShapeOp expandOp) { - // - auto outputShape = expandOp.getStaticOutputShape(); - if (llvm::any_of(outputShape, - [](int64_t x) { return x == ShapedType::kDynamic; })) { - LDBG("Output shape must be static: " << expandOp << "\n"); - return failure(); +/// Reshape operation like expand_shape process helper class. +/// Inorder to avoid pass too many parameters to function. +struct ReshapeVectorizeHelper { + SmallVector srcVectorizedShape; + llvm::SmallDenseMap shapeScales; + SmallVector resultShape; + SmallVector srcShape; + + ReshapeVectorizeHelper() = default; + ReshapeVectorizeHelper(ArrayRef srcVectorizedShape, + llvm::SmallDenseMap &shapeScales, + ArrayRef resultShape, + ArrayRef srcShape) + : srcVectorizedShape(srcVectorizedShape), shapeScales(shapeScales), + resultShape(resultShape), srcShape(srcShape) {} + + /// Get the magnification factor of dimension size of the shape + void getScalesDim(ArrayRef inputVectorSizes) { + for (auto [idx, vs] : llvm::enumerate(inputVectorSizes)) { + if (vs != resultShape[idx]) + shapeScales[idx] = vs / resultShape[idx]; + } } +}; - return success(); -} +/// Get proper input vector size for the operation. +/// Currently only expandshape and collaspeshape need to handle this. +template ::value || + decay_equiv::value, + T>> +void getReshapeOperationVectorizeShape(ReshapeVectorizeHelper &reshapeHelper) { + reshapeHelper.srcVectorizedShape.clear(); + bool isCollapseOp = decay_equiv::value; + int64_t cur = 1, resultIdx = 0; + + for (auto [srcIdx, ss] : llvm::enumerate(reshapeHelper.srcShape)) { + cur *= ss; + if (isCollapseOp) { + reshapeHelper.srcVectorizedShape.emplace_back(ss); + } -LogicalResult lowerBitcastOpPrecondition(tensor::BitcastOp bitCastOp) { - if (bitCastOp.getSource().getType().getNumDynamicDims()) { - LDBG("Type must be static: " << bitCastOp << "\n"); - return failure(); + if (cur != reshapeHelper.resultShape[resultIdx]) { + continue; + } + + if (!isCollapseOp) { + reshapeHelper.srcVectorizedShape.emplace_back(cur); + } + + if (isCollapseOp and reshapeHelper.shapeScales.count(resultIdx)) { + reshapeHelper.srcVectorizedShape.back() *= + reshapeHelper.shapeScales[resultIdx]; + } + + if (!isCollapseOp and reshapeHelper.shapeScales.count(srcIdx)) { + reshapeHelper.srcVectorizedShape.back() *= + reshapeHelper.shapeScales[srcIdx]; + } + + cur = 1; + resultIdx++; } - return success(); } -/// Need to check if the reassociation are static/constant. +/// Need to check whether the reassociation, input, output and input vectorize +/// size are valid. 
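To make the shape bookkeeping in the reshape helper above concrete (a worked example with invented sizes): collapsing `tensor<16x8x8xf32>` into `tensor<16x64xf32>` with user-requested vector sizes `[16, 128]` gives a scale of 2 on the collapsed dimension, so the helper computes a vectorized source shape of `[16, 8, 16]` for the read, which is then shape-cast to `vector<16x128xf32>`. A hypothetical input in that situation:

```mlir
// Invented example: with requested sizes [16, 128], the source of this
// collapse is read as vector<16x8x16xf32> before the shape cast.
func.func @collapse_example(%src: tensor<16x8x8xf32>) -> tensor<16x64xf32> {
  %dst = tensor.collapse_shape %src [[0], [1, 2]]
      : tensor<16x8x8xf32> into tensor<16x64xf32>
  return %dst : tensor<16x64xf32>
}
```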
+template ::value || + decay_equiv::value, + T>> LogicalResult -lowerCollapseShapeOpPrecondition(tensor::CollapseShapeOp collapseOp) { - auto isShapeStatic = [](Value v) { - auto type = mlir::dyn_cast(v.getType()); - if (!type) { - LDBG("Operation type error: " << v << "\n"); - return false; - } - return type.hasStaticShape(); - }; - if (!isShapeStatic(collapseOp->getResults()[0])) { - LDBG("Output shape must be static: " << collapseOp << "\n"); +lowerReshapeOpPrecondition(T reshapeOp, + ArrayRef inputVectorSizes = {}) { + + Type resultType = reshapeOp->getResultTypes()[0]; + auto resultShapeType = cast(resultType); + RankedTensorType srcShapeType = reshapeOp.getSrcType(); + + // check reassociation + SmallVector associateIndices; + + for (const Attribute &attr : reshapeOp.getReassociation()) { + llvm::transform( + cast(attr), std::back_inserter(associateIndices), + [](Attribute indice) { return cast(indice).getInt(); }); + } + if (isContainsDynamicSize(associateIndices)) { + LDBG("Reassociation must be static: " << reshapeOp << "\n"); return failure(); } - if (!isShapeStatic(collapseOp.getSrc())) { - LDBG("Input shape must be static: " << collapseOp << "\n"); + + // check input and output shape + bool isStaticInputOutput = + resultShapeType.hasStaticShape() && srcShapeType.hasStaticShape(); + if (!isStaticInputOutput) { + LDBG("Input and output shape must be static: " << reshapeOp << "\n"); + return failure(); + } + // ensure specify input vector size is valid + if (!inputVectorSizes.empty() && + failed(vector::isValidMaskedInputVector(resultShapeType.getShape(), + inputVectorSizes))) + return failure(); + + if (!llvm::all_of(llvm::zip(resultShapeType.getShape(), inputVectorSizes), + [](std::tuple sizePair) { + int64_t staticSize = std::get<0>(sizePair); + int64_t inputSize = std::get<1>(sizePair); + return inputSize % staticSize == 0; + })) { + LDBG("Input vector sizes must be an integer multiple or equal to " + "space static sizes"); return failure(); } return success(); } -LogicalResult lowerConcatOpPrecondition(tensor::ConcatOp concatOp) { - for (auto x : concatOp->getOperands()) { - auto tensorType = mlir::dyn_cast(x.getType()); - if (!tensorType) { - LDBG("Operation type error: " << concatOp << "\n"); - return failure(); - } - if (tensorType.getNumDynamicDims()) { - LDBG("Type must be static: " << concatOp << "\n"); - return failure(); - } +LogicalResult +lowerConcatOpPrecondition(tensor::ConcatOp concatOp, + ArrayRef inputVectorSizes = {}) { + if (!inputVectorSizes.empty()) { + LDBG("Concat operation does not support specify inputVectorSizes: " + << concatOp << "\n"); + } + // check input operand shape type + if (not llvm::all_of(concatOp.getOperandTypes(), [](Type x) { + return cast(x).hasStaticShape(); + })) { + LDBG("Type must be static: " << concatOp << "\n"); + return failure(); + } + // check valid dimension + uint64_t dim = concatOp.getDim(); + if (dim >= (uint64_t)concatOp.getResultType().getRank()) { + LDBG("Invalid dim: " << concatOp << "\n"); + return failure(); } return success(); } -LogicalResult lowerTargetOpPrecondition(Operation *op) { +LogicalResult lowerTargetOpPrecondition(Operation *op, + ArrayRef inputVectorSizes) { return TypeSwitch(op) .Case([&](auto expandShapeOp) { - return lowerExpandOpPrecondition(expandShapeOp); + return lowerReshapeOpPrecondition( + expandShapeOp, inputVectorSizes); }) .Case([&](auto collapseShapeOp) { - return lowerCollapseShapeOpPrecondition(collapseShapeOp); + return lowerReshapeOpPrecondition( + collapseShapeOp, inputVectorSizes); + 
}) + .Case([&](auto concatOp) { + return lowerConcatOpPrecondition(concatOp, inputVectorSizes); }) - .Case( - [&](auto bitCastOp) { return lowerBitcastOpPrecondition(bitCastOp); }) - .Case( - [&](auto concatOp) { return lowerConcatOpPrecondition(concatOp); }) .Default([](auto) { return failure(); }); } Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value input, - SmallVector destSizes, + ArrayRef destSizes, ArrayRef inputVectorSizes, bool useInBoundsInsteadOfMasking) { @@ -242,8 +327,7 @@ Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source, tensor::getMixedSizes(builder, loc, source); Value mask = builder.create(loc, maskType, mixedSourceDims); - return mlir::vector::maskOperation(builder, transferReadOp, mask) - ->getResult(0); + return vector::maskOperation(builder, transferReadOp, mask)->getResult(0); } /// Vectorize a `tensor::expandshape` to these 3 Ops: @@ -252,77 +336,54 @@ Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source, /// vector::TransferWriteOp. - Write the result vector back to the destination /// tensor template -LogicalResult lowerTensorExpandShapeOp(RewriterBase &rewriter, - Operation *inputOp, - SmallVectorImpl &newResults) { +LogicalResult lowerTensorReshapeOp(RewriterBase &rewriter, Operation *inputOp, + SmallVectorImpl &newResults, + ArrayRef inputVectorSizes) { OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(inputOp); - auto src = inputOp->getOperand(0); - auto srcType = mlir::dyn_cast(src.getType()); - auto result = inputOp->getResults()[0]; - auto resultType = mlir::dyn_cast(result.getType()); + auto src = inputOp->getOperand(0); + auto srcType = cast(src.getType()); + OpResult result = inputOp->getResults()[0]; + auto resultType = cast(result.getType()); ArrayRef resultShape = resultType.getShape(); + ArrayRef srcShape = srcType.getShape(); Location loc = inputOp->getLoc(); - // read + SmallVector srcVectorizedShape(srcType.getRank()); + llvm::SmallDenseMap shapeScales; + ReshapeVectorizeHelper reshapeHelper(srcVectorizedShape, shapeScales, + resultShape, srcShape); + + srcVectorizedShape.assign(srcShape.begin(), srcShape.end()); + if (!inputVectorSizes.empty()) { + reshapeHelper.getScalesDim(inputVectorSizes); + getReshapeOperationVectorizeShape(reshapeHelper); + } + // generate read operation auto padValue = rewriter.create( loc, rewriter.getZeroAttr(srcType.getElementType())); - Value readResult = createReadOrMaskedRead( - rewriter, loc, src, srcType.getShape(), padValue, false); + Value readResult = vector::createReadOrMaskedRead( + rewriter, loc, src, + inputVectorSizes.empty() ? srcType.getShape() : srcVectorizedShape, + padValue, true); auto shapeCastType = - VectorType::get(resultType.getShape(), resultType.getElementType()); + VectorType::get(inputVectorSizes.empty() ? resultShape : inputVectorSizes, + resultType.getElementType()); vector::ShapeCastOp shapeCastOp = rewriter.create(loc, shapeCastType, readResult); - // write - SmallVector destSizes; - for (auto size : resultShape) { - destSizes.emplace_back(rewriter.getIndexAttr(size)); - } - Operation *write = - createWriteOrMaskedWrite(rewriter, loc, shapeCastOp->getResults()[0], - destSizes, resultShape, false); - newResults.push_back(write->getResult(0)); - return success(); -} - -/// Vectorize a `tensor::bitcast` to these 3 Ops: -/// vector::TransferReadOp - Reads a vector from the source tensor -/// vector.Bitcast - Bitcast the data based on the target. -/// vector::TransferWriteOp. 
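For reference, the read / shape-cast / write sequence produced by the reshape lowering above looks roughly like the following for a static `tensor.expand_shape` with no explicit vector sizes. This is a sketch, not the pass's verbatim output: shapes and SSA names are invented, and the exact in-bounds or masking attributes depend on the sizes involved.

```mlir
func.func @expand_lowered(%src: tensor<8x128xf32>) -> tensor<8x16x8xf32> {
  %c0 = arith.constant 0 : index
  %pad = arith.constant 0.000000e+00 : f32
  %empty = tensor.empty() : tensor<8x16x8xf32>
  // Read the whole source as a vector.
  %read = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]}
      : tensor<8x128xf32>, vector<8x128xf32>
  // Reshape in vector form.
  %cast = vector.shape_cast %read : vector<8x128xf32> to vector<8x16x8xf32>
  // Write the result back into the destination tensor.
  %res = vector.transfer_write %cast, %empty[%c0, %c0, %c0] {in_bounds = [true, true, true]}
      : vector<8x16x8xf32>, tensor<8x16x8xf32>
  return %res : tensor<8x16x8xf32>
}
```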
- Write the result vector back to the destination -/// tensor -LogicalResult lowerTensorBitcastOp(RewriterBase &rewriter, - tensor::BitcastOp bitCastOp, - SmallVectorImpl &newResults) { - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPoint(bitCastOp); - - auto sourceType = bitCastOp.getSource().getType(); - auto resultType = bitCastOp.getResult().getType(); - auto resultShape = resultType.getShape(); - Location loc = bitCastOp->getLoc(); + // generate write operation + SmallVector destSizes(resultShape.size()); + llvm::transform(resultShape, std::begin(destSizes), [&rewriter](size_t size) { + return rewriter.getIndexAttr(size); + }); - auto padValue = rewriter.create( - loc, rewriter.getZeroAttr(sourceType.getElementType())); - Value readResult = createReadOrMaskedRead( - rewriter, loc, bitCastOp.getSource(), resultShape, padValue, false); - - auto resultVectorType = - VectorType::get(resultShape, resultType.getElementType()); - vector::BitCastOp vectorbitCastOp = - rewriter.create(loc, resultVectorType, readResult); - - SmallVector writeMaskShape( - vectorbitCastOp.getResultVectorType().getShape()); - llvm::SmallVector destSizes; - for (auto size : resultShape) - destSizes.emplace_back(rewriter.getIndexAttr(size)); - auto write = - createWriteOrMaskedWrite(rewriter, loc, vectorbitCastOp->getResult(0), - destSizes, resultShape, false); - newResults.push_back(write->getResults()[0]); + Operation *write = createWriteOrMaskedWrite( + rewriter, loc, shapeCastOp->getResults()[0], destSizes, + inputVectorSizes.empty() ? resultShape : inputVectorSizes, true); + newResults.push_back(write->getResult(0)); return success(); } @@ -354,8 +415,7 @@ LogicalResult lowerTensorConcatOp(RewriterBase &rewriter, rewriter.create(loc, rewriter.getIndexAttr(dim)); int64_t rank = concatOp.getResultType().getRank(); - auto srcType = - mlir::dyn_cast(concatOp->getResultTypes()[0]); + auto srcType = cast(concatOp->getResultTypes()[0]); auto padValue = rewriter.create( loc, rewriter.getZeroAttr(srcType.getElementType())); @@ -366,15 +426,14 @@ LogicalResult lowerTensorConcatOp(RewriterBase &rewriter, SmallVector sizes = tensor::getMixedSizes(rewriter, loc, input); - SmallVector readMaskShape; - auto inputType = mlir::dyn_cast(input.getType()); - auto sourceShape = inputType.getShape(); + auto inputType = cast(input.getType()); + + Value readResult = createReadOrMaskedRead( + rewriter, loc, input, inputType.getShape(), padValue, true); - readMaskShape.append(sourceShape.begin(), sourceShape.end()); - Value readResult = createReadOrMaskedRead(rewriter, loc, input, sourceShape, - padValue, false); Value zero = rewriter.create(loc, 0); SmallVector indices(rank, zero); + // update write position indices[dim] = previous_offset; result = rewriter .create( @@ -390,12 +449,13 @@ LogicalResult lowerTensorConcatOp(RewriterBase &rewriter, return success(); } -LogicalResult convert2TargetOperation(RewriterBase &rewriter, Operation *op) { +LogicalResult convert2TargetOperation(RewriterBase &rewriter, Operation *op, + ArrayRef inputVectorSizes = {}) { LDBG("Attempting to vectorize:\n" << *op << "\n"); + LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs())); + LLVM_DEBUG(llvm::dbgs() << "\n"); - if (failed(lowerTargetOpPrecondition(op))) { - llvm::outs() << "FAILED TO LOWER TARGET OP\n" - << "\n"; + if (failed(lowerTargetOpPrecondition(op, inputVectorSizes))) { LDBG("Vectorization pre-conditions failed\n"); return failure(); } @@ -404,15 +464,12 @@ LogicalResult convert2TargetOperation(RewriterBase 
&rewriter, Operation *op) { auto lowerResult = TypeSwitch(op) .Case([&](auto expandShapeOp) { - return lowerTensorExpandShapeOp( - rewriter, expandShapeOp, results); + return lowerTensorReshapeOp( + rewriter, expandShapeOp, results, inputVectorSizes); }) .Case([&](auto collapseShapeOp) { - return lowerTensorExpandShapeOp( - rewriter, collapseShapeOp, results); - }) - .Case([&](auto bitCastOp) { - return lowerTensorBitcastOp(rewriter, bitCastOp, results); + return lowerTensorReshapeOp( + rewriter, collapseShapeOp, results, inputVectorSizes); }) .Case([&](auto concatOp) { return lowerTensorConcatOp(rewriter, concatOp, results); @@ -432,43 +489,59 @@ LogicalResult convert2TargetOperation(RewriterBase &rewriter, Operation *op) { return success(); } -bool is_required_tensorOp(Operation *operation) { - return isa(operation); -} - template struct OperationConvertTileVectorPass : public RewritePattern { - explicit OperationConvertTileVectorPass(MLIRContext *context, - bool vectorizeNDExtract = false, - bool flatten1DDepthwiseConv = false) +private: + /// specify vectorize size + SmallVector inputVectorSizes; + /// keep those parameters for future use + bool vectorizeNDExtract, flatten1DDepthwiseConv; + +public: + explicit OperationConvertTileVectorPass( + MLIRContext *context, ArrayRef inputVectorSizes = {}, + bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false) : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context), + inputVectorSizes(inputVectorSizes), vectorizeNDExtract(vectorizeNDExtract), flatten1DDepthwiseConv(flatten1DDepthwiseConv) {} LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { - auto targetOp = llvm::dyn_cast(op); - if (!targetOp || !is_innermost_ir(op)) + auto targetOp = dyn_cast(op); + if (!targetOp) return rewriter.notifyMatchFailure(op, "Not expected operations."); // linalg.fill + linalgx.batch_mutmul should not be lower to vector - // because these two operation is needed by brgemm optimization. - if (isMatchedOperationUsage(op)) { + // because these two operation is needed by brgemm kernel. + if (isMatchedOperationSet(op)) { return rewriter.notifyMatchFailure( - op, "linalg.fill + linalgx.batch_matmul can't do lowering."); + op, "linalg.fill + linalg.matmul can't do lowering."); } - - return linalg::vectorize(rewriter, op, /*inputVectorSizes=*/{}, - /*scalableVecDims=*/{}, vectorizeNDExtract, - flatten1DDepthwiseConv); + SmallVector scalableVecDims(inputVectorSizes.size(), false); + if (failed(linalg::vectorize(rewriter, op, + /*inputVectorSizes=*/inputVectorSizes, + /*inputScalableVecDims=*/scalableVecDims, + vectorizeNDExtract, flatten1DDepthwiseConv))) { + return rewriter.notifyMatchFailure(op, "Fail to vectorize."); + } + return success(); } - -private: - bool vectorizeNDExtract, flatten1DDepthwiseConv; }; +/// Lower tensor.unpack operation to vector. +/// +/// The reason why we don't use `OperationConvertTileVectorPass` is we +/// need to specify input vector size due to unpack operation does not support +/// empty vector size. It's logic is not consistent with other tensor operation. +/// It would be better we split this process logic as a standalone class to +/// notify unpack operation is not support empty vector size. We need to support +/// it like other operation in the future. +/// +/// TODO: Need to support upstream to handle empty vector size. Currently +/// upstream folks don't allow me to do this. It's weird, I can't find reason. 
struct TensorUnpackConvertVectorPass : public RewritePattern { explicit TensorUnpackConvertVectorPass(MLIRContext *context) @@ -476,69 +549,55 @@ struct TensorUnpackConvertVectorPass : public RewritePattern { LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { - tensor::UnPackOp tensorUnPackOp = dyn_cast(op); - - if (!tensorUnPackOp || !is_innermost_ir(op)) + auto tensorUnPackOp = dyn_cast(op); + if (!tensorUnPackOp) return rewriter.notifyMatchFailure(op, "Not expected operations."); - Value resultValue = op->getResult(0); - auto resultTy = dyn_cast(resultValue.getType()); - if (!resultTy) - return rewriter.notifyMatchFailure(op, "expected ranked tensor type"); - - llvm::ArrayRef inputShape = resultTy.getShape(); - std::vector targetVectorSizes = inputShape.vec(); - llvm::SmallVector targetVecDims(inputShape.size(), false); - return linalg::vectorize(rewriter, op, - /*inputVectorSizes=*/targetVectorSizes, - /*scalableVecDims=*/targetVecDims, false, false); + auto resultTy = cast(op->getResultTypes()[0]); + // TODO: Need to support upstream to handle empty vector size. Currently + // upstream folks don't allow me to do this. + ArrayRef inputShape = resultTy.getShape(); + SmallVector targetVecDims(inputShape.size(), false); + + if (failed(linalg::vectorize(rewriter, op, + /*inputVectorSizes=*/inputShape.vec(), + /*inputScalableVecDims=*/targetVecDims, false, + false))) { + return rewriter.notifyMatchFailure(op, "Fail to vectorize."); + } + return success(); } }; +/// Some tensor operation lowering to vector. +/// +/// Currently support expand_shape, collapse_shape and concat_shape. +/// May need support other operation in the future. struct TensorOpConvertVectorPass : public RewritePattern { +private: + SmallVector inputVectorSizes; +public: explicit TensorOpConvertVectorPass(MLIRContext *context, - bool vectorizeExtract = false, - bool flatten1DDepthwiseConv = false) - : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} + ArrayRef inputVectorSizes = {}) + : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context), + inputVectorSizes(inputVectorSizes) {} + LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { - bool is_target = is_required_tensorOp(op); - if (!is_target || !is_innermost_ir(op)) + bool is_target = isRequiredTensorOp(op); + if (!is_target) return rewriter.notifyMatchFailure(op, "Not expected operations."); - return convert2TargetOperation(rewriter, op); - } -}; - -struct EliminateWriteReadOpPass - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(vector::TransferReadOp op, - PatternRewriter &rewriter) const override { - auto sourceOp = op->getOperand(0).getDefiningOp(); - if (isa_and_nonnull(sourceOp)) { - rewriter.replaceOp(op, sourceOp->getOperand(0)); - return success(); + if (failed(convert2TargetOperation(rewriter, op, inputVectorSizes))) { + return rewriter.notifyMatchFailure(op, "Fail to vectorize."); } - return failure(); + return success(); } }; -void eliminateWriteReadOperation(Operation *op) { - if (!isa_and_nonnull(op)) { - return; - } - auto sourceOp = op->getOperand(0).getDefiningOp(); - if (isa_and_nonnull(sourceOp)) { - IRRewriter rewriter(op); - rewriter.replaceOp(op, sourceOp->getOperand(0)); - } -} - -/// Pass that lower to tile vector. +/// Patterns that lower to tile (virtual) vector. 
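Similarly, the `tensor.concat` lowering earlier in this file builds a chain of writes into a single destination tensor, advancing the offset along the concatenation dimension for each input. A rough sketch of the intended result, with invented shapes and the offsets shown as constants:

```mlir
func.func @concat_lowered(%a: tensor<4x8xf32>, %b: tensor<4x8xf32>) -> tensor<4x16xf32> {
  %c0 = arith.constant 0 : index
  %c8 = arith.constant 8 : index
  %pad = arith.constant 0.000000e+00 : f32
  %empty = tensor.empty() : tensor<4x16xf32>
  %va = vector.transfer_read %a[%c0, %c0], %pad {in_bounds = [true, true]}
      : tensor<4x8xf32>, vector<4x8xf32>
  // First input is written at offset 0 along the concat dimension.
  %w0 = vector.transfer_write %va, %empty[%c0, %c0] {in_bounds = [true, true]}
      : vector<4x8xf32>, tensor<4x16xf32>
  %vb = vector.transfer_read %b[%c0, %c0], %pad {in_bounds = [true, true]}
      : tensor<4x8xf32>, vector<4x8xf32>
  // Second input is written at the accumulated offset (8) along that dimension.
  %w1 = vector.transfer_write %vb, %w0[%c0, %c8] {in_bounds = [true, true]}
      : vector<4x8xf32>, tensor<4x16xf32>
  return %w1 : tensor<4x16xf32>
}
```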
void populateLowerToTileVectorPatterns(RewritePatternSet &patterns) { patterns.add, OperationConvertTileVectorPass>( @@ -547,53 +606,69 @@ void populateLowerToTileVectorPatterns(RewritePatternSet &patterns) { patterns.add(patterns.getContext()); } +/// LowerToTileVectorPass is a pass that lowers operations to tile (virtual) +/// vector. We must aware that this pass do not support dynamic shape currently. struct LowerTileVectorPass : public impl::LowerToTileVectorBase { void runOnOperation() final { // auto *ctx = &getContext(); - RewritePatternSet patterns(ctx); + RewritePatternSet patternsInit(ctx); auto funcOp = getOperation(); tensor::ControlFoldFn defaultControlFn = [](OpOperand *fusedOperand) { Operation *producer = fusedOperand->get().getDefiningOp(); return producer && producer->hasOneUse(); }; - // some operation convert as constant, this pattern can help us to improve - // the performance - // tensor::populateRewriteAsConstantPatterns(patterns, defaultControlFn); - // remove unnessary operation - // tensor::populateReassociativeReshapeFoldingPatterns(patterns); - // tensor::populateFoldTensorSubsetOpPatterns(patterns); - // tensor::populateFoldTensorEmptyPatterns(patterns, true); - // tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns); - populateLowerToTileVectorPatterns(patterns); - linalg::populatePadOpVectorizationPatterns(patterns); - - GreedyRewriteConfig config; - config.strictMode = GreedyRewriteStrictness::ExistingOps; - (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns), config); - // error case: - // due to insert slice tensor<1x32xf32> to tensor<1x128x1x32xf32> + // Some operation convert as constant, this pattern can help us to improve + // the performance. + tensor::populateRewriteAsConstantPatterns(patternsInit, defaultControlFn); + // Remove unnessary operation like extract slice and insert slice + tensor::populateReassociativeReshapeFoldingPatterns(patternsInit); + tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patternsInit); + tensor::populateFoldTensorSubsetOpPatterns(patternsInit); + // Pad operation will lower to linalg.fill. We lower it in init patterns + // then lower the fill operation in second patterns. + linalg::populatePadOpVectorizationPatterns(patternsInit); + + GreedyRewriteConfig configInit; + // Init patterns use to remove useless tensor operation like extract or + // insert slice. + configInit.strictMode = GreedyRewriteStrictness::ExistingOps; + (void)applyPatternsAndFoldGreedily(funcOp, std::move(patternsInit), + configInit); + + RewritePatternSet firstPatterns(ctx); + // All the dynamic shape will reject to lower. + populateLowerToTileVectorPatterns(firstPatterns); + GreedyRewriteConfig configFirstPn; + // We only apply the lowering pattern on existing operations + configFirstPn.strictMode = GreedyRewriteStrictness::ExistingOps; + (void)applyPatternsAndFoldGreedily(funcOp, std::move(firstPatterns), + configFirstPn); + // Error case: + // ``` // linalg.copy : <1x32xf32> - // -> transfer_write : permutation map = (d0, d1, d2, d3) -> (d0, d3) - // Inorder to avoid the fold greedily bug (fold wrong permution map for the - // transfer_write operation). Give it the new full IR to fold second time - // can fold correctly. + // tensor.insert_slice tensor<1x32xf32> to tensor<1x128x1x32xf32> + // --> lowering as: + // transfer_write : permutation map = (d0, d1, d2, d3) -> (d0, d3) + // ``` + // Inorder to avoid the fold greedily bug (fold wrong + // permution map for the transfer_write operation). 
Give it the new full IR + // to fold second time can fold correctly. RewritePatternSet secondPattern(ctx); - // secondPattern.add(patterns.getContext()); - // ensure read and write on last dimension + // Ensure each operation has a clear semantics, rather than a composite + // semantics. Instead of leaving it to the subsequent passes to handle these + // complex semantics, it reduces the difficulty of handling operations in + // the subsequent passes. Like transfer_read and transfer_write may have + // transpose or braodcast semantic etc. vector::populateVectorTransferPermutationMapLoweringPatterns(secondPattern); - // remove unnessary broadcast operation + // Remove unnessary broadcast operation vector::populateSinkVectorBroadcastPatterns(secondPattern); - // vector::TransferReadOp::getCanonicalizationPatterns(secondPattern, ctx); - // vector::TransferWriteOp::getCanonicalizationPatterns(secondPattern, ctx); - // tensor::populateFoldTensorSubsetIntoVectorTransferPatterns(secondPattern); - + // Second fold can help us to eliminate redundant operation like consecutive + // read and write. (void)applyPatternsAndFoldGreedily(funcOp, std::move(secondPattern)); - // DominanceInfo domInfo; - // IRRewriter rewriter(funcOp); - // eliminateCommonSubExpressions(rewriter, domInfo, funcOp); + // may need other patterns to reduce redundant operations } }; } // namespace From 747f63afa335b0649453deb2d38a0c12b56b782a Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Thu, 22 Aug 2024 09:01:45 +0800 Subject: [PATCH 25/66] remove lower tile part --- include/gc/Transforms/Passes.td | 15 - lib/gc/Transforms/LowerTileVectorPass.cpp | 680 ---------------------- 2 files changed, 695 deletions(-) delete mode 100644 lib/gc/Transforms/LowerTileVectorPass.cpp diff --git a/include/gc/Transforms/Passes.td b/include/gc/Transforms/Passes.td index 5c8692d8b..21b571cca 100644 --- a/include/gc/Transforms/Passes.td +++ b/include/gc/Transforms/Passes.td @@ -106,21 +106,6 @@ def MergeNestedForall : Pass<"merge-nested-forall"> { let dependentDialects = ["scf::SCFDialect"]; } -def LowerToTileVector : Pass<"lower-to-tile-vector", "func::FuncOp"> { - let summary = "Lower tensor to tile vector."; - let description = [{ - Lower tensor to tile vector form. - }]; - let dependentDialects = [ - "::mlir::func::FuncDialect", - "::mlir::math::MathDialect", - "::mlir::arith::ArithDialect", - "::mlir::tensor::TensorDialect", - "::mlir::linalg::LinalgDialect", - "::mlir::vector::VectorDialect", - ]; -} - def CPUPhysicalRegisterPass : Pass<"CPU-physical-register-pass", "func::FuncOp"> { let summary = "Lower operation to cpu pysical register size."; let description = [{ diff --git a/lib/gc/Transforms/LowerTileVectorPass.cpp b/lib/gc/Transforms/LowerTileVectorPass.cpp deleted file mode 100644 index 4976f2ff4..000000000 --- a/lib/gc/Transforms/LowerTileVectorPass.cpp +++ /dev/null @@ -1,680 +0,0 @@ -//===-- LowerTileVectorPass.cpp - Lower Op to vector ------------*- C++ -*-===// -// -// This file is licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -#include "gc/Dialect/Linalgx/LinalgxOps.h" -#include "gc/Transforms/Passes.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/Dialect/Math/IR/Math.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Tensor/Transforms/Transforms.h" -#include "mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" -#include "mlir/IR/AffineMap.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinTypeInterfaces.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/Support/Casting.h" -#include - -namespace mlir { -namespace gc { - -#define GEN_PASS_DEF_LOWERTOTILEVECTOR -#include "gc/Transforms/Passes.h.inc" -namespace { -#define DEBUG_TYPE "lower-to-tile-vector-pass" - -#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") -#define SAFE_EXPAND(X) X -#define LDBG(X) LLVM_DEBUG(DBGS() << SAFE_EXPAND(X) << "\n") - -#define IMPLEMENTED_MATMUL \ - linalgx::BatchReduceMatmulVnniOp, linalgx::MultiBatchMatmulOp, \ - linalg::BatchReduceMatmulOp, linalgx::Mm2DVnniOp, linalgx::Mm4DVnniOp, \ - linalg::MatmulOp, linalg::BatchMatmulOp, \ - linalg::BatchMatmulTransposeAOp, linalg::BatchMatmulTransposeBOp, \ - linalg::MatmulTransposeAOp, linalg::MatmulTransposeBOp, \ - linalg::QuantizedBatchMatmulOp, linalg::QuantizedMatmulOp - -#define SUPPORT_TENSOR_OP \ - tensor::ExpandShapeOp, tensor::CollapseShapeOp, tensor::ConcatOp - -template -struct decay_equiv : std::is_same::type, U>::type {}; - -static inline bool isRequiredTensorOp(Operation *operation) { - return isa(operation); -} - -/// matmul operation or fill + matmul operation -static bool isMatchedOperationSet(Operation *op) { - if (isa(op)) { - return true; - } - // Operation produce for matmul can't lower. - // Currently only the fill operation need to check this. - if (!isa(op)) { - return false; - } - - return llvm::any_of(op->getUsers(), - [](Operation *x) { return isa(x); }); -} - -static bool isContainsDynamicSize(ArrayRef sizes) { - return llvm::any_of(sizes, - [](int64_t x) { return x == ShapedType::kDynamic; }); -} - -/// Reshape operation like expand_shape process helper class. -/// Inorder to avoid pass too many parameters to function. 
-struct ReshapeVectorizeHelper { - SmallVector srcVectorizedShape; - llvm::SmallDenseMap shapeScales; - SmallVector resultShape; - SmallVector srcShape; - - ReshapeVectorizeHelper() = default; - ReshapeVectorizeHelper(ArrayRef srcVectorizedShape, - llvm::SmallDenseMap &shapeScales, - ArrayRef resultShape, - ArrayRef srcShape) - : srcVectorizedShape(srcVectorizedShape), shapeScales(shapeScales), - resultShape(resultShape), srcShape(srcShape) {} - - /// Get the magnification factor of dimension size of the shape - void getScalesDim(ArrayRef inputVectorSizes) { - for (auto [idx, vs] : llvm::enumerate(inputVectorSizes)) { - if (vs != resultShape[idx]) - shapeScales[idx] = vs / resultShape[idx]; - } - } -}; - -/// Get proper input vector size for the operation. -/// Currently only expandshape and collaspeshape need to handle this. -template ::value || - decay_equiv::value, - T>> -void getReshapeOperationVectorizeShape(ReshapeVectorizeHelper &reshapeHelper) { - reshapeHelper.srcVectorizedShape.clear(); - bool isCollapseOp = decay_equiv::value; - int64_t cur = 1, resultIdx = 0; - - for (auto [srcIdx, ss] : llvm::enumerate(reshapeHelper.srcShape)) { - cur *= ss; - if (isCollapseOp) { - reshapeHelper.srcVectorizedShape.emplace_back(ss); - } - - if (cur != reshapeHelper.resultShape[resultIdx]) { - continue; - } - - if (!isCollapseOp) { - reshapeHelper.srcVectorizedShape.emplace_back(cur); - } - - if (isCollapseOp and reshapeHelper.shapeScales.count(resultIdx)) { - reshapeHelper.srcVectorizedShape.back() *= - reshapeHelper.shapeScales[resultIdx]; - } - - if (!isCollapseOp and reshapeHelper.shapeScales.count(srcIdx)) { - reshapeHelper.srcVectorizedShape.back() *= - reshapeHelper.shapeScales[srcIdx]; - } - - cur = 1; - resultIdx++; - } -} - -/// Need to check whether the reassociation, input, output and input vectorize -/// size are valid. 
-template ::value || - decay_equiv::value, - T>> -LogicalResult -lowerReshapeOpPrecondition(T reshapeOp, - ArrayRef inputVectorSizes = {}) { - - Type resultType = reshapeOp->getResultTypes()[0]; - auto resultShapeType = cast(resultType); - RankedTensorType srcShapeType = reshapeOp.getSrcType(); - - // check reassociation - SmallVector associateIndices; - - for (const Attribute &attr : reshapeOp.getReassociation()) { - llvm::transform( - cast(attr), std::back_inserter(associateIndices), - [](Attribute indice) { return cast(indice).getInt(); }); - } - if (isContainsDynamicSize(associateIndices)) { - LDBG("Reassociation must be static: " << reshapeOp << "\n"); - return failure(); - } - - // check input and output shape - bool isStaticInputOutput = - resultShapeType.hasStaticShape() && srcShapeType.hasStaticShape(); - if (!isStaticInputOutput) { - LDBG("Input and output shape must be static: " << reshapeOp << "\n"); - return failure(); - } - // ensure specify input vector size is valid - if (!inputVectorSizes.empty() && - failed(vector::isValidMaskedInputVector(resultShapeType.getShape(), - inputVectorSizes))) - return failure(); - - if (!llvm::all_of(llvm::zip(resultShapeType.getShape(), inputVectorSizes), - [](std::tuple sizePair) { - int64_t staticSize = std::get<0>(sizePair); - int64_t inputSize = std::get<1>(sizePair); - return inputSize % staticSize == 0; - })) { - LDBG("Input vector sizes must be an integer multiple or equal to " - "space static sizes"); - return failure(); - } - - return success(); -} - -LogicalResult -lowerConcatOpPrecondition(tensor::ConcatOp concatOp, - ArrayRef inputVectorSizes = {}) { - if (!inputVectorSizes.empty()) { - LDBG("Concat operation does not support specify inputVectorSizes: " - << concatOp << "\n"); - } - // check input operand shape type - if (not llvm::all_of(concatOp.getOperandTypes(), [](Type x) { - return cast(x).hasStaticShape(); - })) { - LDBG("Type must be static: " << concatOp << "\n"); - return failure(); - } - // check valid dimension - uint64_t dim = concatOp.getDim(); - if (dim >= (uint64_t)concatOp.getResultType().getRank()) { - LDBG("Invalid dim: " << concatOp << "\n"); - return failure(); - } - - return success(); -} - -LogicalResult lowerTargetOpPrecondition(Operation *op, - ArrayRef inputVectorSizes) { - - return TypeSwitch(op) - .Case([&](auto expandShapeOp) { - return lowerReshapeOpPrecondition( - expandShapeOp, inputVectorSizes); - }) - .Case([&](auto collapseShapeOp) { - return lowerReshapeOpPrecondition( - collapseShapeOp, inputVectorSizes); - }) - .Case([&](auto concatOp) { - return lowerConcatOpPrecondition(concatOp, inputVectorSizes); - }) - .Default([](auto) { return failure(); }); -} - -Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc, - Value input, - ArrayRef destSizes, - ArrayRef inputVectorSizes, - bool useInBoundsInsteadOfMasking) { - - auto inputType = cast(input.getType()); - Value dest = builder.create(loc, destSizes, - inputType.getElementType()); - int64_t rank = cast(dest.getType()).getRank(); - auto zero = builder.create(loc, 0); - auto destShape = cast(dest.getType()).getShape(); - SmallVector inBoundsVal(rank, true); - if (useInBoundsInsteadOfMasking) { - // Update the inBounds attribute. 
- for (unsigned i = 0; i < rank; i++) - inBoundsVal[i] = (destShape[i] == inputVectorSizes[i]) && - !ShapedType::isDynamic(destShape[i]); - } - Operation *write = builder.create( - loc, - /*vector=*/input, - /*source=*/dest, - /*indices=*/SmallVector(rank, zero), - /*inBounds=*/inBoundsVal); - assert(llvm::none_of( - destShape.drop_front(inputVectorSizes.size()), - [](int64_t size) { return size == ShapedType::kDynamic; }) && - "Only dims aligned with inputVectorSizes may be dynamic"); - if (useInBoundsInsteadOfMasking) - return write; - bool needMaskForWrite = !llvm::equal( - inputVectorSizes, destShape.take_front(inputVectorSizes.size())); - if (needMaskForWrite) { - SmallVector writeMaskShape; - writeMaskShape.append(inputVectorSizes.begin(), inputVectorSizes.end()); - writeMaskShape.append(destShape.begin() + inputVectorSizes.size(), - destShape.end()); - auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type()); - Value maskForWrite = - builder.create(loc, writeMaskType, destSizes); - write = mlir::vector::maskOperation(builder, write, maskForWrite); - } - return write; -} - -Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source, - ArrayRef readShape, Value padValue, - bool useInBoundsInsteadOfMasking) { - assert(llvm::none_of(readShape, - [](int64_t s) { return s == ShapedType::kDynamic; }) && - "expected static shape"); - auto sourceShapedType = cast(source.getType()); - auto sourceShape = sourceShapedType.getShape(); - assert(sourceShape.size() == readShape.size() && "expected same ranks."); - auto maskType = VectorType::get(readShape, builder.getI1Type()); - auto vectorType = VectorType::get(readShape, padValue.getType()); - assert(padValue.getType() == sourceShapedType.getElementType() && - "expected same pad element type to match source element type"); - int64_t readRank = readShape.size(); - auto zero = builder.create(loc, 0); - SmallVector inBoundsVal(readRank, true); - if (useInBoundsInsteadOfMasking) { - // Update the inBounds attribute. - for (unsigned i = 0; i < readRank; i++) - inBoundsVal[i] = (sourceShape[i] == readShape[i]) && - !ShapedType::isDynamic(sourceShape[i]); - } - auto transferReadOp = builder.create( - loc, - /*vectorType=*/vectorType, - /*source=*/source, - /*indices=*/SmallVector(readRank, zero), - /*padding=*/padValue, - /*inBounds=*/inBoundsVal); - - if (llvm::equal(readShape, sourceShape) || useInBoundsInsteadOfMasking) - return transferReadOp; - SmallVector mixedSourceDims = - tensor::getMixedSizes(builder, loc, source); - Value mask = - builder.create(loc, maskType, mixedSourceDims); - return vector::maskOperation(builder, transferReadOp, mask)->getResult(0); -} - -/// Vectorize a `tensor::expandshape` to these 3 Ops: -/// Vector::TransferReadOp - Reads a vector from the source tensor -/// ShapeCastOp - Reshape the data based on the target. -/// vector::TransferWriteOp. 
- Write the result vector back to the destination -/// tensor -template -LogicalResult lowerTensorReshapeOp(RewriterBase &rewriter, Operation *inputOp, - SmallVectorImpl &newResults, - ArrayRef inputVectorSizes) { - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPoint(inputOp); - - auto src = inputOp->getOperand(0); - auto srcType = cast(src.getType()); - OpResult result = inputOp->getResults()[0]; - auto resultType = cast(result.getType()); - ArrayRef resultShape = resultType.getShape(); - ArrayRef srcShape = srcType.getShape(); - Location loc = inputOp->getLoc(); - - SmallVector srcVectorizedShape(srcType.getRank()); - llvm::SmallDenseMap shapeScales; - ReshapeVectorizeHelper reshapeHelper(srcVectorizedShape, shapeScales, - resultShape, srcShape); - - srcVectorizedShape.assign(srcShape.begin(), srcShape.end()); - if (!inputVectorSizes.empty()) { - reshapeHelper.getScalesDim(inputVectorSizes); - getReshapeOperationVectorizeShape(reshapeHelper); - } - // generate read operation - auto padValue = rewriter.create( - loc, rewriter.getZeroAttr(srcType.getElementType())); - Value readResult = vector::createReadOrMaskedRead( - rewriter, loc, src, - inputVectorSizes.empty() ? srcType.getShape() : srcVectorizedShape, - padValue, true); - - auto shapeCastType = - VectorType::get(inputVectorSizes.empty() ? resultShape : inputVectorSizes, - resultType.getElementType()); - vector::ShapeCastOp shapeCastOp = - rewriter.create(loc, shapeCastType, readResult); - - // generate write operation - SmallVector destSizes(resultShape.size()); - llvm::transform(resultShape, std::begin(destSizes), [&rewriter](size_t size) { - return rewriter.getIndexAttr(size); - }); - - Operation *write = createWriteOrMaskedWrite( - rewriter, loc, shapeCastOp->getResults()[0], destSizes, - inputVectorSizes.empty() ? resultShape : inputVectorSizes, true); - newResults.push_back(write->getResult(0)); - return success(); -} - -/// Vectorize a `tensor::concat` to these 3 Ops: -/// Tensor::EmptyOp - The result tensor. -/// Vector::TransferWriteOp - Write the result vector back to the destination -/// tensor. -/// Vector::TransferWriteOp - Write the result vector back to the destination -/// tensor. -LogicalResult lowerTensorConcatOp(RewriterBase &rewriter, - tensor::ConcatOp concatOp, - SmallVectorImpl &newResults) { - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPoint(concatOp); - - Location loc = concatOp.getLoc(); - FailureOr dest = - tensor::getOrCreateDestination(rewriter, loc, concatOp->getResult(0)); - if (failed(dest)) - return failure(); - - auto empty = dest->getDefiningOp(); - if (!empty) - return failure(); - - // Compute the partial sums for the slice offsets. - int64_t dim = concatOp.getDim(); - Value dimValue = - rewriter.create(loc, rewriter.getIndexAttr(dim)); - - int64_t rank = concatOp.getResultType().getRank(); - auto srcType = cast(concatOp->getResultTypes()[0]); - auto padValue = rewriter.create( - loc, rewriter.getZeroAttr(srcType.getElementType())); - - // Construct the chain of insert_slice ops into the destination. 
- Value result = *dest; - Value previous_offset = rewriter.create(loc, 0); - for (auto input : concatOp.getInputs()) { - - SmallVector sizes = - tensor::getMixedSizes(rewriter, loc, input); - auto inputType = cast(input.getType()); - - Value readResult = createReadOrMaskedRead( - rewriter, loc, input, inputType.getShape(), padValue, true); - - Value zero = rewriter.create(loc, 0); - SmallVector indices(rank, zero); - // update write position - indices[dim] = previous_offset; - result = rewriter - .create( - loc, readResult, result, indices, - rewriter.getMultiDimIdentityMap(rank)) - ->getResults()[0]; - auto dimOp = rewriter.create(loc, input, dimValue); - previous_offset = - rewriter.create(loc, dimOp, previous_offset); - } - - newResults.push_back(result); - return success(); -} - -LogicalResult convert2TargetOperation(RewriterBase &rewriter, Operation *op, - ArrayRef inputVectorSizes = {}) { - LDBG("Attempting to vectorize:\n" << *op << "\n"); - LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs())); - LLVM_DEBUG(llvm::dbgs() << "\n"); - - if (failed(lowerTargetOpPrecondition(op, inputVectorSizes))) { - LDBG("Vectorization pre-conditions failed\n"); - return failure(); - } - - SmallVector results; - auto lowerResult = - TypeSwitch(op) - .Case([&](auto expandShapeOp) { - return lowerTensorReshapeOp( - rewriter, expandShapeOp, results, inputVectorSizes); - }) - .Case([&](auto collapseShapeOp) { - return lowerTensorReshapeOp( - rewriter, collapseShapeOp, results, inputVectorSizes); - }) - .Case([&](auto concatOp) { - return lowerTensorConcatOp(rewriter, concatOp, results); - }) - .Default([](auto) { return failure(); }); - - if (failed(lowerResult)) { - LDBG("Lower failed\n"); - return failure(); - } - - if (!results.empty()) - rewriter.replaceOp(op, results); - else - rewriter.eraseOp(op); - - return success(); -} - -template -struct OperationConvertTileVectorPass : public RewritePattern { - -private: - /// specify vectorize size - SmallVector inputVectorSizes; - /// keep those parameters for future use - bool vectorizeNDExtract, flatten1DDepthwiseConv; - -public: - explicit OperationConvertTileVectorPass( - MLIRContext *context, ArrayRef inputVectorSizes = {}, - bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false) - : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context), - inputVectorSizes(inputVectorSizes), - vectorizeNDExtract(vectorizeNDExtract), - flatten1DDepthwiseConv(flatten1DDepthwiseConv) {} - - LogicalResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const override { - - auto targetOp = dyn_cast(op); - if (!targetOp) - return rewriter.notifyMatchFailure(op, "Not expected operations."); - - // linalg.fill + linalgx.batch_mutmul should not be lower to vector - // because these two operation is needed by brgemm kernel. - if (isMatchedOperationSet(op)) { - return rewriter.notifyMatchFailure( - op, "linalg.fill + linalg.matmul can't do lowering."); - } - SmallVector scalableVecDims(inputVectorSizes.size(), false); - if (failed(linalg::vectorize(rewriter, op, - /*inputVectorSizes=*/inputVectorSizes, - /*inputScalableVecDims=*/scalableVecDims, - vectorizeNDExtract, flatten1DDepthwiseConv))) { - return rewriter.notifyMatchFailure(op, "Fail to vectorize."); - } - return success(); - } -}; - -/// Lower tensor.unpack operation to vector. -/// -/// The reason why we don't use `OperationConvertTileVectorPass` is we -/// need to specify input vector size due to unpack operation does not support -/// empty vector size. 
It's logic is not consistent with other tensor operation. -/// It would be better we split this process logic as a standalone class to -/// notify unpack operation is not support empty vector size. We need to support -/// it like other operation in the future. -/// -/// TODO: Need to support upstream to handle empty vector size. Currently -/// upstream folks don't allow me to do this. It's weird, I can't find reason. -struct TensorUnpackConvertVectorPass : public RewritePattern { - - explicit TensorUnpackConvertVectorPass(MLIRContext *context) - : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} - LogicalResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const override { - - auto tensorUnPackOp = dyn_cast(op); - if (!tensorUnPackOp) - return rewriter.notifyMatchFailure(op, "Not expected operations."); - - auto resultTy = cast(op->getResultTypes()[0]); - // TODO: Need to support upstream to handle empty vector size. Currently - // upstream folks don't allow me to do this. - ArrayRef inputShape = resultTy.getShape(); - SmallVector targetVecDims(inputShape.size(), false); - - if (failed(linalg::vectorize(rewriter, op, - /*inputVectorSizes=*/inputShape.vec(), - /*inputScalableVecDims=*/targetVecDims, false, - false))) { - return rewriter.notifyMatchFailure(op, "Fail to vectorize."); - } - return success(); - } -}; - -/// Some tensor operation lowering to vector. -/// -/// Currently support expand_shape, collapse_shape and concat_shape. -/// May need support other operation in the future. -struct TensorOpConvertVectorPass : public RewritePattern { -private: - SmallVector inputVectorSizes; - -public: - explicit TensorOpConvertVectorPass(MLIRContext *context, - ArrayRef inputVectorSizes = {}) - : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context), - inputVectorSizes(inputVectorSizes) {} - - LogicalResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const override { - - bool is_target = isRequiredTensorOp(op); - if (!is_target) - return rewriter.notifyMatchFailure(op, "Not expected operations."); - - if (failed(convert2TargetOperation(rewriter, op, inputVectorSizes))) { - return rewriter.notifyMatchFailure(op, "Fail to vectorize."); - } - return success(); - } -}; - -/// Patterns that lower to tile (virtual) vector. -void populateLowerToTileVectorPatterns(RewritePatternSet &patterns) { - patterns.add, - OperationConvertTileVectorPass>( - patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); -} - -/// LowerToTileVectorPass is a pass that lowers operations to tile (virtual) -/// vector. We must aware that this pass do not support dynamic shape currently. -struct LowerTileVectorPass - : public impl::LowerToTileVectorBase { - void runOnOperation() final { - // - auto *ctx = &getContext(); - RewritePatternSet patternsInit(ctx); - auto funcOp = getOperation(); - - tensor::ControlFoldFn defaultControlFn = [](OpOperand *fusedOperand) { - Operation *producer = fusedOperand->get().getDefiningOp(); - return producer && producer->hasOneUse(); - }; - // Some operation convert as constant, this pattern can help us to improve - // the performance. 
- tensor::populateRewriteAsConstantPatterns(patternsInit, defaultControlFn); - // Remove unnessary operation like extract slice and insert slice - tensor::populateReassociativeReshapeFoldingPatterns(patternsInit); - tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patternsInit); - tensor::populateFoldTensorSubsetOpPatterns(patternsInit); - // Pad operation will lower to linalg.fill. We lower it in init patterns - // then lower the fill operation in second patterns. - linalg::populatePadOpVectorizationPatterns(patternsInit); - - GreedyRewriteConfig configInit; - // Init patterns use to remove useless tensor operation like extract or - // insert slice. - configInit.strictMode = GreedyRewriteStrictness::ExistingOps; - (void)applyPatternsAndFoldGreedily(funcOp, std::move(patternsInit), - configInit); - - RewritePatternSet firstPatterns(ctx); - // All the dynamic shape will reject to lower. - populateLowerToTileVectorPatterns(firstPatterns); - GreedyRewriteConfig configFirstPn; - // We only apply the lowering pattern on existing operations - configFirstPn.strictMode = GreedyRewriteStrictness::ExistingOps; - (void)applyPatternsAndFoldGreedily(funcOp, std::move(firstPatterns), - configFirstPn); - // Error case: - // ``` - // linalg.copy : <1x32xf32> - // tensor.insert_slice tensor<1x32xf32> to tensor<1x128x1x32xf32> - // --> lowering as: - // transfer_write : permutation map = (d0, d1, d2, d3) -> (d0, d3) - // ``` - // Inorder to avoid the fold greedily bug (fold wrong - // permution map for the transfer_write operation). Give it the new full IR - // to fold second time can fold correctly. - RewritePatternSet secondPattern(ctx); - // Ensure each operation has a clear semantics, rather than a composite - // semantics. Instead of leaving it to the subsequent passes to handle these - // complex semantics, it reduces the difficulty of handling operations in - // the subsequent passes. Like transfer_read and transfer_write may have - // transpose or braodcast semantic etc. - vector::populateVectorTransferPermutationMapLoweringPatterns(secondPattern); - // Remove unnessary broadcast operation - vector::populateSinkVectorBroadcastPatterns(secondPattern); - // Second fold can help us to eliminate redundant operation like consecutive - // read and write. 
- (void)applyPatternsAndFoldGreedily(funcOp, std::move(secondPattern)); - // may need other patterns to reduce redundant operations - } -}; -} // namespace - -std::unique_ptr createLowerTileVectorPass() { - return std::make_unique(); -} -} // namespace gc -} // namespace mlir \ No newline at end of file From 788452e602c9349a278d578c3786683a7aa0c753 Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Thu, 22 Aug 2024 09:33:40 +0800 Subject: [PATCH 26/66] update --- include/gc/Transforms/TilingVector.h | 3 +- lib/gc/Transforms/CMakeLists.txt | 3 +- ...erPass.cpp => CPUPhysicalRegisterPass.cpp} | 309 ++++++++++++++---- .../gc/transforms/cpu-vetor-distribution.mlir | 2 +- 4 files changed, 253 insertions(+), 64 deletions(-) rename lib/gc/Transforms/{CPUPhysicalResigterPass.cpp => CPUPhysicalRegisterPass.cpp} (93%) diff --git a/include/gc/Transforms/TilingVector.h b/include/gc/Transforms/TilingVector.h index 3ccd7ed7e..4a6531ce8 100644 --- a/include/gc/Transforms/TilingVector.h +++ b/include/gc/Transforms/TilingVector.h @@ -420,7 +420,8 @@ class CanonicalizerCommonUsedData : public TypeHelper { void updateOpOperandResultInGroups(size_t opGid, Operation *op, const Value &init = Value(), const Value &result = Value()); - void removeOpInCurrentGroups(size_t grpIdx, Operation *op); + void removeOpInCurrentGroups(size_t grpIdx, Operation *op, + Operation *replacedOp); void updateOpGroupInfo(size_t grpIdx); Value diff --git a/lib/gc/Transforms/CMakeLists.txt b/lib/gc/Transforms/CMakeLists.txt index f07ae37f7..486c1dc67 100644 --- a/lib/gc/Transforms/CMakeLists.txt +++ b/lib/gc/Transforms/CMakeLists.txt @@ -18,13 +18,12 @@ gc_add_mlir_library(GcPasses VerifyTargetDescription.cpp DeepTileContractionNamedOp.cpp TilingUtil.cpp - LowerTileVectorPass.cpp - CPUPhysicalResigterPass.cpp SinkOpIntoInnerLoop.cpp MergeNestedForall.cpp FoldTensorOperation.cpp LowerToTileVector.cpp + CPUPhysicalRegisterPass.cpp DEPENDS GraphCompilerPassIncGen diff --git a/lib/gc/Transforms/CPUPhysicalResigterPass.cpp b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp similarity index 93% rename from lib/gc/Transforms/CPUPhysicalResigterPass.cpp rename to lib/gc/Transforms/CPUPhysicalRegisterPass.cpp index 9d4e6e657..d026a5a3a 100644 --- a/lib/gc/Transforms/CPUPhysicalResigterPass.cpp +++ b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp @@ -6,7 +6,22 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// +#include "gc/Dialect/Microkernel/MicrokernelOps.h" #include "gc/Transforms/TilingVector.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/Visitors.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/Support/raw_ostream.h" +#include +#include +#include namespace mlir { namespace gc { @@ -16,7 +31,8 @@ namespace { #define DEBUG_TYPE "lower-to-physical-register-pass" #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") -#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") +#define SAFE_EXPAND(X) X +#define LDBG(X) LLVM_DEBUG(DBGS() << SAFE_EXPAND(X) << "\n") #define ARITH_CAST_OPERATIONS \ arith::ExtFOp, arith::ExtSIOp, arith::ExtUIOp, arith::BitcastOp, \ @@ -29,12 +45,13 @@ namespace { linalg::MatmulOp, linalg::BatchMatmulOp, \ linalg::BatchMatmulTransposeAOp, 
linalg::BatchMatmulTransposeBOp, \ linalg::MatmulTransposeAOp, linalg::MatmulTransposeBOp, \ - linalg::QuantizedBatchMatmulOp, linalg::QuantizedMatmulOp + linalg::QuantizedBatchMatmulOp, linalg::QuantizedMatmulOp, \ + microkernel::BrgemmOp // TODO: remove it in the future bool disableSpecialOp = false; -bool disableBroadcastOp = true; -bool enableDebugPrinter = false; +bool disableBroadcastOp = false; +bool enableDebugPrinter = true; void printQueue(const std::queue &opQueue) { auto tempQ(opQueue); @@ -59,6 +76,11 @@ void printGroupOps(SmallVector, 8> &opGroups) { } } +static inline bool isCandidateMoveOperations(Operation *op) { + return isa(op); +} + static inline bool isSpecialLinalgOp(Operation *op) { return isa(op); } @@ -646,9 +668,8 @@ T getInitValForReduce(vector::CombiningKind kind, Type t) { // Since we rewrite transfer_read and transfer_write, the `permutationmap` must // be changed. void setOpVectorizationPermutationMap(Operation *op, OpBuilder &rewriter, - const RankedTensorType &tensorType, + const ShapedType &tensorType, const AffineMap &permutationMap) { - auto dimExpr = permutationMap.getResults(); auto lastDim = dyn_cast(dimExpr.back()); assert(isa(lastDim)); @@ -659,11 +680,11 @@ void setOpVectorizationPermutationMap(Operation *op, OpBuilder &rewriter, rewriter.getContext()); SmallVector inBounds(1, true); if (mlir::isa(op)) { - auto transferWriteOp = mlir::dyn_cast(op); + auto transferWriteOp = cast(op); transferWriteOp.setPermutationMap(destAffineMap); transferWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds)); } else if (mlir::isa(op)) { - auto transferReadOp = mlir::dyn_cast(op); + auto transferReadOp = cast(op); transferReadOp.setPermutationMap(destAffineMap); transferReadOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds)); } @@ -973,8 +994,7 @@ vector::TransferReadOp ForLoopGenerator::cloneReductionTransferRead( : getScalarType(newReadOp); newReadOp->getResult(0).setType(newOperandType); setOpVectorizationPermutationMap( - newReadOp, b, - mlir::dyn_cast(newReadOp.getSource().getType()), + newReadOp, b, cast(newReadOp.getSource().getType()), newReadOp.getPermutationMap()); rewriter.replaceOp(readOp, newReadOp); @@ -992,8 +1012,7 @@ makeNewTransferWriteOp(Value source, IRMapping &writeMap, OpBuilder &b, updateReduceReadWriteOperationOperand(inductionVars, parallelAxis, newWriteOp, MultiReduceOpAxisKind::Parallel); setOpVectorizationPermutationMap( - newWriteOp, b, - mlir::dyn_cast(newWriteOp->getResult(0).getType()), + newWriteOp, b, cast(newWriteOp->getResult(0).getType()), newWriteOp.getPermutationMap()); bodyRewriter.replaceOp(writeOp, newWriteOp); return newWriteOp; @@ -2263,7 +2282,7 @@ scf::ForOp ForLoopGenerator::generateTransposeForLoop(const size_t grpIdx) { // don't need to do the transpose if (tpCanonicalizer.isTransposeOnAllOneDim()) { - removeOpInCurrentGroups(grpIdx, tpOp); + removeOpInCurrentGroups(grpIdx, tpOp, tpOp->getOperand(0).getDefiningOp()); // generate nested for loop SmallVector nextLoopResults; @@ -2338,11 +2357,12 @@ SmallVector &SpecialOperationCanonicalizer::getCandidateOps() { }; void MultiReductionCanonicalizer::initReductionAxis() { - auto reductionAxisRange = - getCandidateOps()[0].getReductionDims().getAsValueRange(); - auto reductionRange = llvm::to_vector<4>(map_range( - reductionAxisRange, [](const APInt &a) { return a.getZExtValue(); })); - reductionAxis.assign(reductionRange.begin(), reductionRange.end()); + // auto reductionAxisRange = + // getCandidateOps()[0].getReductionDims().getAsValueRange(); + // auto 
reductionRange = llvm::to_vector<4>(map_range( + // reductionAxisRange, [](const APInt &a) { return a.getZExtValue(); })); + auto reductionAxisRange = getCandidateOps()[0].getReductionDims(); + reductionAxis.assign(reductionAxisRange.begin(), reductionAxisRange.end()); llvm::sort(reductionAxis); } @@ -2549,8 +2569,6 @@ void CanonicalizerVectorOperation::initSpeicalOperationCanonicalizers() { } else if (isa(op)) { getShapeCastCanonicalizers().back().getCandidateOps().emplace_back( cast(op)); - llvm::outs() << " current shape cast op: "; - op->dump(); } } } @@ -2773,6 +2791,7 @@ scf::ForOp ForLoopGenerator::constructNestedForOp( DenseMap &nextAnchorResultsIdxMap, DenseMap &forResultOrignalResultMap, DenseMap> &indiceLoopMap) { + llvm::outs() << groupIdx << "\n"; const int loop_step = getFusionStrategy().getGroupMaxSteps()[groupIdx]; // loop initialization variable auto zero = makeIndexArithConstantOp(b, loc, 0); @@ -2996,11 +3015,13 @@ void getOperationDataAxis(Operation *op, SmallVector &dataAxis) { return TypeSwitch(op) .Case( [&](vector::MultiDimReductionOp multiReductionOp) { - auto rdDimsRange = multiReductionOp.getReductionDims() - .getAsValueRange(); - auto reductionDims = llvm::to_vector<4>(map_range( - rdDimsRange, [](const APInt &a) { return a.getZExtValue(); })); - dataAxis.assign(reductionDims.begin(), reductionDims.end()); + // auto rdDimsRange = multiReductionOp.getReductionDims() + // .getAsValueRange(); + // auto reductionDims = llvm::to_vector<4>(map_range( + // rdDimsRange, [](const APInt &a) { return a.getZExtValue(); + // })); + auto rdDimsRange = multiReductionOp.getReductionDims(); + dataAxis.assign(rdDimsRange.begin(), rdDimsRange.end()); }) .Case([&](vector::ShapeCastOp shapeCastOp) { auto srcType = shapeCastOp.getSourceVectorType(); @@ -3168,8 +3189,8 @@ bool hasDataDependency(Operation *op1, Operation *op2) { return true; }) .Case([&](vector::BroadcastOp broadcastOp) { - if (isa(op2)) { - return false; + if (isSpecialOp(op2)) { + return true; } return !OpTrait::util::staticallyKnownBroadcastable( getOperationVectorType(op1, false)->getShape(), @@ -3179,29 +3200,53 @@ bool hasDataDependency(Operation *op1, Operation *op2) { SmallVector dims1, dims2; getOperationDataAxis(op1, dims1); getOperationDataAxis(op2, dims2); - if (!isSpecialOp(op2)) { - return hasSameAxis(dims1, dims2); - } return true; + // if (!isSpecialOp(op2)) { + // return hasSameAxis(dims1, dims2); + // } + // return true; }) .Default([&](Operation *op) { return false; }); return res; } +/// Get the operation which is not a read-write in current queue +/// \param [in, out] op +Operation *getNotReadWriteOperaiton(std::queue &tmpQ) { + Operation *op = nullptr; + while (!tmpQ.empty()) { + Operation *cur = tmpQ.front(); + tmpQ.pop(); + if (isReadOrWriteOperation(cur)) { + continue; + } + op = cur; + } + return op; +} + bool VectorFusionStrategy::isNeedNewGroup(Operation *op) { // 1. check previous operation if (!opGroups.back().empty()) { - auto prevOp = opGroups.back().back(); + + // We only care about the calculation operation. 
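+  // (Roughly: a trailing transfer_read/transfer_write in the queue does
+  // not by itself decide grouping; the checks below compare the incoming
+  // op against the most recent computation operation found in the queue,
+  // if there is one.)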
+ std::queue tmpQ(opGroups.back()); + Operation *prevOp = nullptr; + prevOp = getNotReadWriteOperaiton(tmpQ); + if (!prevOp) { + return false; + } + // not in the same operation if (prevOp->getParentOp() != op->getParentOp()) { return true; } // read and write operation dependency - if (readWriteDependency(prevOp, op)) { - return true; - } + // if (readWriteDependency(prevOp, op)) { + // return true; + // } // special operation need to check data dependency axis if (hasDataDependency(prevOp, op)) { @@ -3257,7 +3302,7 @@ void VectorFusionStrategy::classifyOperations() { addOperationToGroup(op); } else if (isSpecialLinalgOp(op)) { // following operation need a new group - if (opGroups.back().size() > 0) { + if (!opGroups.back().empty()) { opGroups.emplace_back(std::queue()); } } @@ -3363,8 +3408,7 @@ void ForLoopGenerator::createNewConstantOp( {*transferWriteOp, transferWriteOp->getPermutationMap()}); setOpVectorizationPermutationMap( *transferWriteOp, srcWriter, - mlir::dyn_cast( - transferWriteOp->getResults()[0].getType()), + cast(transferWriteOp->getResults()[0].getType()), transferWriteOp->getPermutationMap()); return; } @@ -3405,7 +3449,7 @@ void ForLoopGenerator::rewriteOperationAsVectorize( setOpVectorizationPermutationMap( transferWriteOp, rewriter, - dyn_cast( + cast( transferWriteOp->getResult(0).getType()), transferWriteOp.getPermutationMap()); } @@ -3419,8 +3463,7 @@ void ForLoopGenerator::rewriteOperationAsVectorize( transferReadOp->getResult(0).setType(newOperandType); setOpVectorizationPermutationMap( transferReadOp, rewriter, - mlir::dyn_cast( - transferReadOp.getSource().getType()), + cast(transferReadOp.getSource().getType()), transferReadOp.getPermutationMap()); return success(); @@ -3470,8 +3513,8 @@ mlir::FailureOr getOperationOperateTensor(Operation *op) { }); } -void CanonicalizerCommonUsedData::removeOpInCurrentGroups(size_t grpIdx, - Operation *op) { +void CanonicalizerCommonUsedData::removeOpInCurrentGroups( + size_t grpIdx, Operation *op, Operation *replacedOp) { std::queue tmpOpQueue(getFusionStrategy().getOpGroups()[grpIdx]); std::queue newOpQueue; while (!tmpOpQueue.empty()) { @@ -3487,19 +3530,18 @@ void CanonicalizerCommonUsedData::removeOpInCurrentGroups(size_t grpIdx, getFusionStrategy().getOpGroups()[grpIdx] = newOpQueue; // erase and replace the operation - Operation *defOp = op->getOperand(0).getDefiningOp(); SmallVector usesOp(op->getUsers().begin(), op->getUsers().end()); IRRewriter rewriter(op); - rewriter.replaceOp(op, op->getOperand(0)); + rewriter.replaceOp(op, replacedOp); // update removed operation related operation anchor position - getFusionStrategy().getOpAnchorPos()[defOp] = - getOperationMaxVectorType(defOp)->getRank() - 1; + getFusionStrategy().getOpAnchorPos()[replacedOp] = + getOperationMaxVectorType(replacedOp)->getRank() - 1; for (Operation *x : usesOp) { getFusionStrategy().getOpAnchorPos()[x] = getOperationMaxVectorType(x)->getRank() - 1; } - // update operaiton in grpIdx group related information + // update operation in grpIdx group related information updateOpGroupInfo(grpIdx); } @@ -3741,6 +3783,7 @@ void VectorOperationAnalyzer::analysisGroupOperaion() { SmallVector, 8> &groupOpInitArgs = getGroupOpInitArgs(); DenseMap &OpAnchorPos = getFusionStrategy().getOpAnchorPos(); + IRRewriter rewriter(func); analysisGroupMaxSteps(); @@ -3827,6 +3870,7 @@ void VectorOperationAnalyzer::analysisGroupOperaion() { if (idx == 0) continue; } + if (isa(op)) { if (idx == 1) { // accumulate value, just empty tensor is okay @@ -3849,7 +3893,7 @@ 
void VectorOperationAnalyzer::analysisGroupOperaion() { if (isa(constantOp.getValue())) { VectorType newOperandType = getVectorzedType(op, groupSteps); - auto valueType = dyn_cast(constantOp.getValue()); + auto valueType = cast(constantOp.getValue()); if (valueType.isSplat()) { FailureOr res = createArithSplatConstantOp( rewriter, constantOp->getLoc(), valueType, newOperandType); @@ -3858,21 +3902,19 @@ void VectorOperationAnalyzer::analysisGroupOperaion() { } op->setOperand(idx, res.value()); } else { - // TODO: need to test not splat value - // newConstantOp = srcWriter.create( - // srcOp->getLoc(), srcConstantOp.getValue()); - // if (!srcOpCanoniclizedMap.contains(sourceOp)) { - // auto [tsr, writeOpresult] = - // canonicalizeSourceOperation(sourceOp, visitedOperation); - // srcOpCanoniclizedMap.insert({sourceOp, {tsr, - // writeOpresult}}); - // } - // auto opInit = canonicalizeCurrentOperation( - // op, srcOpCanoniclizedMap[sourceOp].second, idx); - // updateOpOperandResultInGroups(opGroupIndexMap[op], op, opInit); llvm::llvm_unreachable_internal( "Can't support not splat constant value."); } + + // transfer read operation just use the constant value to do + // calculation, don't need to read. + if (isa(op)) { + if (idx == 0) { + removeOpInCurrentGroups(opGroupIndexMap[op], op, + op->getOperand(0).getDefiningOp()); + continue; + } + } } } } @@ -4022,6 +4064,149 @@ void ForLoopGenerator::generateGroupOpVectorizedIR(const int idx) { moveLoopInvariantCode(forOp.value()); } +LogicalResult +moveFront(Operation *op, + llvm::DenseMap &operationPosition) { + IRRewriter rewriter(op); + Operation *backOperation; + size_t pos = 0; + // check all the operand is block argument + bool allBlockArgs = true; + for (auto operand : op->getOperands()) { + if (!isa(operand)) { + allBlockArgs = false; + break; + } + } + if (allBlockArgs) { + rewriter.moveOpAfter(op, op->getBlock(), op->getBlock()->begin()); + return success(); + } + for (auto operand : op->getOperands()) { + if (isa(operand)) { + continue; + } + if (operationPosition[operand.getDefiningOp()] > pos and + operand.getDefiningOp()->getBlock() == op->getBlock()) { + backOperation = operand.getDefiningOp(); + pos = operationPosition[operand.getDefiningOp()]; + } + } + if (pos == 0) { + // extract operand operation all in previous block + rewriter.moveOpBefore(op, op->getBlock(), op->getBlock()->begin()); + return success(); + } + if (backOperation) { + rewriter.moveOpAfter(op, backOperation); + return success(); + } + return failure(); +} + +LogicalResult moveBack(Operation *op, + llvm::DenseMap &operationPosition) { + IRRewriter rewriter(op); + Operation *firstOperation; + size_t pos = 0; + for (auto user : op->getUsers()) { + if (operationPosition[user] > pos and user->getBlock() == op->getBlock()) { + firstOperation = user; + pos = operationPosition[user]; + } + } + if (pos == 0) { + // Don't move. + // TODO: need to consider move to before the block which use it. 
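+  // (Reaching this branch effectively means no user of `op` was found in
+  // its own block, so the operation is left in place for now.)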
+ return success(); + } + if (firstOperation) { + rewriter.moveOpBefore(op, firstOperation); + return success(); + } + return failure(); +} + +void moveCandidateOperation( + llvm::DenseMap &operationPosition, + ArrayRef candidateOps) { + + for (Operation *op : candidateOps) { + auto ret = + TypeSwitch(op) + .Case([&](affine::AffineApplyOp affineOp) { + return moveFront(op, operationPosition); + }) + .Case( + [&](tensor::ExtractSliceOp extractOp) { + return moveFront(op, operationPosition); + }) + .Case([&](tensor::InsertSliceOp insertOp) { + return moveBack(op, operationPosition); + }) + .Case([&](vector::TransferReadOp readOp) { + return moveFront(op, operationPosition); + }) + .Case( + [&](vector::TransferWriteOp writeOp) { + return moveBack(op, operationPosition); + }) + .Default([&](Operation *op) { return success(); }); + if (failed(ret)) { + LDBG("Wrong to move operation:" << *op << "\n"); + return; + } + } +} + +// Need to move some operations like extract_slice or insert_slice. +// Because those operation may interpret our analysis result. e.g.: +// ``` +// clang-format off + // %21 = vector.transfer_read %18[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<16x16xf32>, vector<16x16xf32> + // %22 = arith.addf %21, %20 : vector<16x16xf32> + // %23 = vector.transfer_write %22, %extracted_slice_12[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf32>, tensor<16x16xf32> + // %inserted_slice_13 = tensor.insert_slice %18 into %arg14[%arg13, 0] [16, 16] [1, 1] : tensor<16x16xf32> into tensor<32x16xf32> + // %extracted_slice_14 = tensor.extract_slice %arg16[%arg13, 0] [16, 16] [1, 1] : tensor<32x16xf32> to tensor<16x16xf32> + // %24 = vector.transfer_read %cst_0[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<16x16xf32>, vector<16x16xf32> + // %25 = arith.maximumf %22, %24 : vector<16x16xf32> + // %26 = vector.transfer_write %25, %extracted_slice_14[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf32>, tensor<16x16xf32> + // %inserted_slice_15 = tensor.insert_slice %23 into %arg15[%arg13, 0] [16, 16] [1, 1] : tensor<16x16xf32> into tensor<32x16xf32> + // %inserted_slice_16 = tensor.insert_slice %26 into %arg16[%arg13, 0] [16, 16] [1, 1] : tensor<16x16xf32> into tensor<32x16xf32> +// clang-format on +// ``` +// The maximumf and addf operation can be a same group, but the extract_slice +// operation interpret us. +// The move operation(extra_slice) will check its parameters. In order to +// ensure that it does not affect the correctness of the result, we will only +// move the moved op after the op to which the parameters belong to. If it's +// operand is all the block argument, we will move it to the begining of the +// block. +// insert_slice just move them to the privious of the first operation which +// use it. +void moveSomeInterferenceOperation( + func::FuncOp *func, MLIRContext *ctx, + std::function &conditionalFunc) { + // Pre-order traversal of each op + // Record each operation position. Inorder to we can kown current operation + // should move after which operation. 
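+  // A small sketch of the idea (hypothetical positions, names are only
+  // illustrative): after the walk, IR like `%0 = A; %1 = B(%0); %2 = C(%1)`
+  // yields positions {A:0, B:1, C:2}; moveFront then places a candidate op
+  // right after its latest-positioned producer in the same block, while
+  // moveBack places it just before a user in the same block.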
+ llvm::DenseMap operationPosition; + SmallVector candidateOps; + size_t opCounter = 0; + + // get the position of each operation + func->walk([&](Operation *op) { + operationPosition[op] = opCounter++; + if (conditionalFunc(op)) { + candidateOps.emplace_back(op); + } + }); + moveCandidateOperation(operationPosition, candidateOps); + // eliminate some useless operation + RewritePatternSet patterns(ctx); + (void)applyPatternsAndFoldGreedily(*func, std::move(patterns)); +} + /// Pass that lower to physical vector. struct CPUPhysicalRegisterPass : public impl::CPUPhysicalRegisterPassBase { @@ -4036,6 +4221,8 @@ struct CPUPhysicalRegisterPass LDBG("Not support operation appears in current function."); return; } + std::function candidateFunc = isCandidateMoveOperations; + moveSomeInterferenceOperation(&func, ctx, candidateFunc); // canonicalize vector operation, default use vector-based fusion // strategy. HardWareInfo hwInfo; @@ -4044,6 +4231,8 @@ struct CPUPhysicalRegisterPass CanonicalizerVectorOperation canonicalizer( func, CanonicalizerKind::OperationsGroup, hwInfo); canonicalizer.run(); + candidateFunc = isReadOrWriteOperation; + moveSomeInterferenceOperation(&func, ctx, candidateFunc); // transpose kernel vector::VectorTransformsOptions transposeOptions = diff --git a/test/gc/transforms/cpu-vetor-distribution.mlir b/test/gc/transforms/cpu-vetor-distribution.mlir index 6eb89574a..b6310946a 100644 --- a/test/gc/transforms/cpu-vetor-distribution.mlir +++ b/test/gc/transforms/cpu-vetor-distribution.mlir @@ -1,4 +1,4 @@ -// RUN: gc-opt %s --split-input-file --lower-to-tile-vector --CPU-physical-register-pass --mlir-print-ir-after-all | FileCheck %s +// RUN: gc-opt %s --split-input-file --fold-tensor-operation --lower-to-tile-vector --CPU-physical-register-pass --mlir-print-ir-after-all | FileCheck %s // CHECK-LABEL: func @add_tensor_test0 func.func @add_tensor_test0(%arg0: tensor<11008x4096xf32>, %arg1: tensor<11008x4096xf32>) -> tensor<11008x4096xf32> { From 894f268aed5b71e2c758b33327e40f8eb884db23 Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Thu, 22 Aug 2024 10:57:49 +0800 Subject: [PATCH 27/66] disable printer --- lib/gc/Transforms/CPUPhysicalRegisterPass.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp index d026a5a3a..0d7f51ec5 100644 --- a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp +++ b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp @@ -51,7 +51,7 @@ namespace { // TODO: remove it in the future bool disableSpecialOp = false; bool disableBroadcastOp = false; -bool enableDebugPrinter = true; +bool enableDebugPrinter = false; void printQueue(const std::queue &opQueue) { auto tempQ(opQueue); From 37a0447e4041da36110b257ac436c76ba944974b Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Fri, 23 Aug 2024 21:02:44 +0800 Subject: [PATCH 28/66] fix transpose segmentation fault --- lib/gc/Transforms/CPUPhysicalRegisterPass.cpp | 151 ++++++++---------- 1 file changed, 63 insertions(+), 88 deletions(-) diff --git a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp index 0d7f51ec5..4800cadce 100644 --- a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp +++ b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp @@ -6,22 +6,7 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// -#include "gc/Dialect/Microkernel/MicrokernelOps.h" #include 
"gc/Transforms/TilingVector.h" -#include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/IR/BuiltinTypeInterfaces.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/Value.h" -#include "mlir/IR/Visitors.h" -#include "llvm/ADT/DenseMap.h" -#include "llvm/Support/raw_ostream.h" -#include -#include -#include namespace mlir { namespace gc { @@ -45,8 +30,7 @@ namespace { linalg::MatmulOp, linalg::BatchMatmulOp, \ linalg::BatchMatmulTransposeAOp, linalg::BatchMatmulTransposeBOp, \ linalg::MatmulTransposeAOp, linalg::MatmulTransposeBOp, \ - linalg::QuantizedBatchMatmulOp, linalg::QuantizedMatmulOp, \ - microkernel::BrgemmOp + linalg::QuantizedBatchMatmulOp, linalg::QuantizedMatmulOp // TODO: remove it in the future bool disableSpecialOp = false; @@ -124,7 +108,7 @@ static inline bool isSpecialOp(Operation *op) { bool is_innermost_operation(Operation *op) { bool inner_most = true; op->walk([&inner_most](Operation *p) { - if (mlir::isa(p)) { + if (isa(p)) { inner_most = false; return WalkResult::interrupt(); } @@ -1888,14 +1872,10 @@ ForLoopGenerator::generateMultiReductionForLoop(const size_t grpIdx) { opBuilder, grpIdx, 0, currentLoopStateIdxMap, initArgs, nextAnchorResults, nextAnchorResultsIdxMap, inductionVars, originalOperandLoopArgsMap, loopArgsOriginalOperandMap, forResultOrignalResultMap, indiceLoopMap); + DenseSet forOpChildOps; + forOp->walk([&](Operation *op) { forOpChildOps.insert(op); }); auto replaceIfFn = [&](OpOperand &use) { - auto walkResult = forOp->walk([&](Operation *op) { - if (use.getOwner() == op) { - return WalkResult::interrupt(); - } - return WalkResult::advance(); - }); - return walkResult != WalkResult::interrupt(); + return not forOpChildOps.contains(use.getOwner()); }; for (auto x : nextAnchorResults) { auto originalResult = forResultOrignalResultMap[x]; @@ -1918,7 +1898,7 @@ scf::ForOp ForLoopGenerator::generateTransposeScalarDataMovement( const ValueRange &iterArgs, DenseMap &tpAxisMap) { auto &tpCanonicalizer = getTransposeCanonicalizers()[grpIdx]; vector::TransposeOp &tpOp = tpCanonicalizer.getCandidateOps()[0]; - VectorType vtType = tpOp.getResultVectorType(); + VectorType vtType = tpOp.getSourceVectorType(); size_t rank = vtType.getRank(); auto zero = makeIndexArithConstantOp(opBuilder, loc, 0); @@ -1947,20 +1927,19 @@ scf::ForOp ForLoopGenerator::generateTransposeScalarDataMovement( auto padValue = b.create( loc, b.getZeroAttr(vtType.getElementType())); SmallVector inBoundsVal(1, true); - - auto transferReadOp = b.create( - loc, - /*vectorType=*/kernelType, - /*source=*/readSourceOp.getSource(), - /*indices=*/inductionVars, - /*padding=*/padValue, - /*inBounds=*/inBoundsVal); SmallVector writeVars; size_t itrIdx = 0; while (itrIdx < rank) { writeVars.emplace_back(inductionVars[tpAxisMap[itrIdx]]); itrIdx++; } + auto transferReadOp = b.create( + loc, + /*vectorType=*/kernelType, + /*source=*/readSourceOp.getSource(), + /*indices=*/writeVars, + /*padding=*/padValue, + /*inBounds=*/inBoundsVal); rectifyWriteOperationIndice(&successorWriteOp, writeVars); @@ -2281,52 +2260,53 @@ scf::ForOp ForLoopGenerator::generateTransposeForLoop(const size_t grpIdx) { SmallVector inductionVars; // don't need to do the transpose - if (tpCanonicalizer.isTransposeOnAllOneDim()) { - removeOpInCurrentGroups(grpIdx, tpOp, tpOp->getOperand(0).getDefiningOp()); - - // generate nested for loop - SmallVector 
nextLoopResults; - DenseMap resultIdxMap; - SmallVector inductionVars; - DenseMap forResultOrignalResultMap; - Operation *firstOp = getFusionStrategy().getOpGroups()[grpIdx].front(); - OpBuilder b(firstOp); - VectorType groupVector = - getFusionStrategy().getGroupBiggestRankVectorType()[grpIdx]; - ArrayRef shapes = groupVector.getShape(); - - DenseMap> indiceLoopMap; - - scf::ForOp forOp = constructNestedForOp( - 0, grpIdx, b, firstOp->getLoc(), iterArgs, shapes, inductionVars, - operandIdxMap, originalOperandMap, operandOriginalMap, nextLoopResults, - resultIdxMap, forResultOrignalResultMap, indiceLoopMap); - - auto replaceIfFn = [&](OpOperand &use) { - auto walkResult = forOp->walk([&](Operation *op) { - if (use.getOwner() == op) { - return WalkResult::interrupt(); - } - return WalkResult::advance(); - }); - return walkResult != WalkResult::interrupt(); - }; - for (auto x : nextLoopResults) { - auto originalResult = forResultOrignalResultMap[x]; - rewriter.replaceOpUsesWithIf(originalResult.getDefiningOp(), - forOp->getResults()[resultIdxMap[x]], - replaceIfFn); - rectifyGroupOperands(grpIdx, originalResult, - forOp->getResults()[resultIdxMap[x]]); - } - // clear current group operation - clearCurrentOperationGroup(grpIdx); - return forOp; - } + // if (tpCanonicalizer.isTransposeOnAllOneDim()) { + // removeOpInCurrentGroups(grpIdx, tpOp, + // tpOp->getOperand(0).getDefiningOp()); + + // // generate nested for loop + // SmallVector nextLoopResults; + // DenseMap resultIdxMap; + // SmallVector inductionVars; + // DenseMap forResultOrignalResultMap; + // Operation *firstOp = getFusionStrategy().getOpGroups()[grpIdx].front(); + // OpBuilder b(firstOp); + // VectorType groupVector = + // getFusionStrategy().getGroupBiggestRankVectorType()[grpIdx]; + // ArrayRef shapes = groupVector.getShape(); + + // DenseMap> indiceLoopMap; + + // scf::ForOp forOp = constructNestedForOp( + // 0, grpIdx, b, firstOp->getLoc(), iterArgs, shapes, inductionVars, + // operandIdxMap, originalOperandMap, operandOriginalMap, + // nextLoopResults, resultIdxMap, forResultOrignalResultMap, + // indiceLoopMap); + + // forOp->dump(); + // DenseSet forOpChildOps; + // forOp->walk([&](Operation *op) { forOpChildOps.insert(op); }); + // auto replaceIfFn = [&](OpOperand &use) { + // return not forOpChildOps.contains(use.getOwner()); + // }; + // for (auto x : nextLoopResults) { + // auto originalResult = forResultOrignalResultMap[x]; + // rewriter.replaceOpUsesWithIf(originalResult.getDefiningOp(), + // forOp->getResults()[resultIdxMap[x]], + // replaceIfFn); + // rectifyGroupOperands(grpIdx, originalResult, + // forOp->getResults()[resultIdxMap[x]]); + // } + // // clear current group operation + // clearCurrentOperationGroup(grpIdx); + // return forOp; + // } OpBuilder b(tpOp); int tpStep = TransposeCanonicalizer::TRANSPOSE_KERNEL::KERNEL_16X16; // only contains last dim can use fast transpose algorithm - if (permuteSet.contains(rank - 1) and isTwoDTranspose) { + if ((tpCanonicalizer.getFirstTpIdx() == (rank - 1) or + tpCanonicalizer.getSecondTpIdx() == (rank - 1)) and + isTwoDTranspose) { scf::ForOp forOp = generateTransposeForLoopWithLastDim( b, grpIdx, 0, tpStep, tpOp.getLoc(), inductionVars, iterArgs, operandIdxMap, originalOperandMap); @@ -2791,7 +2771,6 @@ scf::ForOp ForLoopGenerator::constructNestedForOp( DenseMap &nextAnchorResultsIdxMap, DenseMap &forResultOrignalResultMap, DenseMap> &indiceLoopMap) { - llvm::outs() << groupIdx << "\n"; const int loop_step = getFusionStrategy().getGroupMaxSteps()[groupIdx]; // loop 
initialization variable auto zero = makeIndexArithConstantOp(b, loc, 0); @@ -3197,10 +3176,10 @@ bool hasDataDependency(Operation *op1, Operation *op2) { getOperationVectorType(op2)->getShape()); }) .Case([&](vector::TransposeOp transposeOp) { - SmallVector dims1, dims2; - getOperationDataAxis(op1, dims1); - getOperationDataAxis(op2, dims2); return true; + // SmallVector dims1, dims2; + // getOperationDataAxis(op1, dims1); + // getOperationDataAxis(op2, dims2); // if (!isSpecialOp(op2)) { // return hasSameAxis(dims1, dims2); // } @@ -3981,14 +3960,10 @@ mlir::FailureOr ForLoopGenerator::generateVectorizedForLoop( 0, groupId, rewriter, rewriter.getUnknownLoc(), forIterArgs, shapes, inductionVars, operandIdxMap, originalOperandMap, operandOriginalMap, nextLoopResults, resultIdxMap, forResultOrignalResultMap, indiceLoopMap); + DenseSet forOpChildOps; + forOp->walk([&](Operation *op) { forOpChildOps.insert(op); }); auto replaceIfFn = [&](OpOperand &use) { - auto walkResult = forOp->walk([&](Operation *op) { - if (use.getOwner() == op) { - return WalkResult::interrupt(); - } - return WalkResult::advance(); - }); - return walkResult != WalkResult::interrupt(); + return not forOpChildOps.contains(use.getOwner()); }; for (auto x : nextLoopResults) { From a804c2dc5b5e901a4b2e510e4c9363d63e69fe58 Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Wed, 28 Aug 2024 16:08:06 +0800 Subject: [PATCH 29/66] update transpose index --- lib/gc/Transforms/CPUPhysicalRegisterPass.cpp | 76 ++++++++++++++----- lib/gc/Transforms/LowerToTileVector.cpp | 2 +- 2 files changed, 57 insertions(+), 21 deletions(-) diff --git a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp index 4800cadce..b1d71a8f8 100644 --- a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp +++ b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp @@ -30,7 +30,8 @@ namespace { linalg::MatmulOp, linalg::BatchMatmulOp, \ linalg::BatchMatmulTransposeAOp, linalg::BatchMatmulTransposeBOp, \ linalg::MatmulTransposeAOp, linalg::MatmulTransposeBOp, \ - linalg::QuantizedBatchMatmulOp, linalg::QuantizedMatmulOp + linalg::QuantizedBatchMatmulOp, linalg::QuantizedMatmulOp, \ + tensor::CollapseShapeOp, tensor::ExpandShapeOp // TODO: remove it in the future bool disableSpecialOp = false; @@ -60,9 +61,12 @@ void printGroupOps(SmallVector, 8> &opGroups) { } } +static inline bool isUsedByOtherOp(Operation *op) { + return isa(op); +} + static inline bool isCandidateMoveOperations(Operation *op) { - return isa(op); + return isa(op); } static inline bool isSpecialLinalgOp(Operation *op) { @@ -104,6 +108,14 @@ static inline bool isSpecialOp(Operation *op) { op); } +static inline void moveOpBeginingOfBlock(Operation *op) { + Block *block = op->getBlock(); + assert(not block->getOperations().empty() && "Empty block."); + if (&block->front() == op) + return; + op->moveBefore(&block->front()); +} + /// operation should not contain for loop bool is_innermost_operation(Operation *op) { bool inner_most = true; @@ -1937,7 +1949,7 @@ scf::ForOp ForLoopGenerator::generateTransposeScalarDataMovement( loc, /*vectorType=*/kernelType, /*source=*/readSourceOp.getSource(), - /*indices=*/writeVars, + /*indices=*/inductionVars, /*padding=*/padValue, /*inBounds=*/inBoundsVal); @@ -3214,6 +3226,9 @@ bool VectorFusionStrategy::isNeedNewGroup(Operation *op) { Operation *prevOp = nullptr; prevOp = getNotReadWriteOperaiton(tmpQ); if (!prevOp) { + if (opGroups.back().back()->getParentOp() != op->getParentOp()) { + return true; + } return false; } @@ -3479,11 
+3494,19 @@ mlir::FailureOr getOperationOperateTensor(Operation *op) { return TypeSwitch>(op) .Case( [&](vector::TransferWriteOp transferWriteOp) { - LDBG(" DPS operation : " << *op << "\n"); - return transferWriteOp->getOperand(1); + // find original tensor.empty operation + auto writeTensor = transferWriteOp->getOperand(1); + while (auto wtOp = dyn_cast( + writeTensor.getDefiningOp())) { + if (transferWriteOp->getBlock() != + writeTensor.getDefiningOp()->getBlock()) { + break; + } + writeTensor = wtOp->getOperand(1); + } + return writeTensor; }) .Case([&](vector::TransferReadOp transferReadOp) { - LDBG(" DPS operation : " << *op << "\n"); return transferReadOp->getOperand(0); }) .Default([&](Operation *op) { @@ -3803,6 +3826,15 @@ void VectorOperationAnalyzer::analysisGroupOperaion() { vector::TransferWriteOp>(sourceOp, sourceOpGid)) { auto writeOpresult = writeOp->getResults()[0]; auto writeTensor = writeOp->getOperands()[1]; + // find original tensor.empty operation + while (auto wtOp = dyn_cast( + writeTensor.getDefiningOp())) { + if (sourceOp->getBlock() != + writeTensor.getDefiningOp()->getBlock()) { + break; + } + writeTensor = wtOp->getOperand(1); + } srcOpCanoniclizedMap.insert( {sourceOp, {writeTensor, writeOpresult}}); groupOpInitArgs[sourceOpGid].insert(writeTensor); @@ -4054,22 +4086,23 @@ moveFront(Operation *op, } } if (allBlockArgs) { - rewriter.moveOpAfter(op, op->getBlock(), op->getBlock()->begin()); + moveOpBeginingOfBlock(op); return success(); } for (auto operand : op->getOperands()) { - if (isa(operand)) { + if (isa(operand)) continue; - } - if (operationPosition[operand.getDefiningOp()] > pos and - operand.getDefiningOp()->getBlock() == op->getBlock()) { - backOperation = operand.getDefiningOp(); - pos = operationPosition[operand.getDefiningOp()]; + + Operation *sourceOp = operand.getDefiningOp(); + if (operationPosition[sourceOp] > pos and + sourceOp->getBlock() == op->getBlock()) { + backOperation = sourceOp; + pos = operationPosition[sourceOp]; } } if (pos == 0) { // extract operand operation all in previous block - rewriter.moveOpBefore(op, op->getBlock(), op->getBlock()->begin()); + moveOpBeginingOfBlock(op); return success(); } if (backOperation) { @@ -4083,16 +4116,16 @@ LogicalResult moveBack(Operation *op, llvm::DenseMap &operationPosition) { IRRewriter rewriter(op); Operation *firstOperation; - size_t pos = 0; + size_t pos = std::numeric_limits::max(); for (auto user : op->getUsers()) { - if (operationPosition[user] > pos and user->getBlock() == op->getBlock()) { + if (operationPosition[user] < pos and user->getBlock() == op->getBlock()) { firstOperation = user; pos = operationPosition[user]; } } - if (pos == 0) { + if (pos == std::numeric_limits::max()) { // Don't move. - // TODO: need to consider move to before the block which use it. + // TODO: need to consider move before the block which use it. return success(); } if (firstOperation) { @@ -4196,7 +4229,10 @@ struct CPUPhysicalRegisterPass LDBG("Not support operation appears in current function."); return; } - std::function candidateFunc = isCandidateMoveOperations; + // affineApply operation is always used by other operations. + std::function candidateFunc = isUsedByOtherOp; + moveSomeInterferenceOperation(&func, ctx, candidateFunc); + candidateFunc = isCandidateMoveOperations; moveSomeInterferenceOperation(&func, ctx, candidateFunc); // canonicalize vector operation, default use vector-based fusion // strategy. 
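  // (The fusion canonicalizer sizes its vector steps from the hardware
  //  info set up below; presumably a 512-bit target allows wider steps per
  //  loop iteration than a 256-bit one.)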
diff --git a/lib/gc/Transforms/LowerToTileVector.cpp b/lib/gc/Transforms/LowerToTileVector.cpp index 5610cef00..c2e0c895d 100644 --- a/lib/gc/Transforms/LowerToTileVector.cpp +++ b/lib/gc/Transforms/LowerToTileVector.cpp @@ -656,7 +656,7 @@ struct LowerToTileVectorPass // transpose or braodcast semantic etc. vector::populateVectorTransferPermutationMapLoweringPatterns(secondPattern); // Remove unnessary broadcast operation - vector::populateSinkVectorBroadcastPatterns(secondPattern); + vector::populateSinkVectorOpsPatterns(secondPattern); // Second fold (with the help of the `applyPatternsAndFoldGreedily` // function) can help us to eliminate redundant operation like consecutive // read and write. From c1d7136a6cc4f7241a59ca24c18b83cd8c0749eb Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Thu, 29 Aug 2024 20:48:41 +0800 Subject: [PATCH 30/66] simplify analyzer code --- lib/gc/Transforms/CPUPhysicalRegisterPass.cpp | 349 +++++++++--------- {include => lib}/gc/Transforms/TilingVector.h | 16 +- 2 files changed, 182 insertions(+), 183 deletions(-) rename {include => lib}/gc/Transforms/TilingVector.h (96%) diff --git a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp index b1d71a8f8..82fcb3299 100644 --- a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp +++ b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp @@ -1,12 +1,11 @@ -//===- CPUPhysicalResigterPass.cpp.cpp - OneDNNGraph To Linalg -// Lowering -*- C++ -*-===// +//===- CPUPhysicalResigterPass.cpp - tiling as physical vector ---*-C++-*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// -#include "gc/Transforms/TilingVector.h" +#include "TilingVector.h" namespace mlir { namespace gc { @@ -33,7 +32,7 @@ namespace { linalg::QuantizedBatchMatmulOp, linalg::QuantizedMatmulOp, \ tensor::CollapseShapeOp, tensor::ExpandShapeOp -// TODO: remove it in the future +/// TODO: remove it in the future bool disableSpecialOp = false; bool disableBroadcastOp = false; bool enableDebugPrinter = false; @@ -116,6 +115,18 @@ static inline void moveOpBeginingOfBlock(Operation *op) { op->moveBefore(&block->front()); } +/// find the original tensor +Value findOriginalTensor(Value writeTensor, Block *block) { + while (auto wtOp = dyn_cast_or_null( + writeTensor.getDefiningOp())) { + if (block != writeTensor.getDefiningOp()->getBlock()) + break; + + writeTensor = wtOp->getOperand(1); + } + return writeTensor; +} + /// operation should not contain for loop bool is_innermost_operation(Operation *op) { bool inner_most = true; @@ -2609,28 +2620,29 @@ void CanonicalizerVectorOperation::run() { auto &fusionStrategy = getFusionStrategy(); if (kind == CanonicalizerKind::OperationsGroup) { // 1. Analysis the operation's operands and results - // We need to analyze which operation results are needed by other + // We need to analyze which operation's result is needed by other // operations, and we need to pass these results correctly. Mapping the - // operation result value to forloop yeild result value. We can replace - // the operation operand as: map(operand, forloop yield result) -> operand - // = loop yield result We put all the operation result into this map. + // operation result value with the forloop yeild result value. 
We can + // replace the operation operand as: map(operand, forloop yield result) -> + // operand = loop yield result We put all the operation result into this + // map. // 1.a. Find results which should be generated by current group for // using as operands to other operations? // Traverse all operations. If the operand of operations in other groups - // or outside the group is the result of the current group operation, then - // the current operation needs to generate a result. We use `setvector` to - // save the results that need to be generated by the current group. + // or outside the group is the result of the operation in current group, + // then the current operation needs to generate a result. We use `setvector` + // to save the results that need to be generated by the current group. // 1.b. What operands are needed to find in the current group, and where // can they be obtained ? // Thanks to 1.a, we get the result generated by the operations of - // each group, and this result will use `for loop yield` to generate a + // each group, and this result will use `scf.yield` to generate a // new result. Since the scope of the parent block of mlir is covered // the current operation, the current operation does not need to pass - // these `for loop results` to the `iterArgs` of the required `for loop`. + // these `for loop result` to the `iterArgs` of the required `for loop`. // It only needs to replace the operand of the current operation with the // corresponding `for loop yield result`. @@ -2642,9 +2654,6 @@ void CanonicalizerVectorOperation::run() { // needs to be read from the tensor before the current operation operate // on it. Therefore, `empty tensor`, `transfer_write` and `transfer_read` // need to be inserted at target place. - - // Query groupResultYeildSet to map operaion result value to scf.yield - // result value. 
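  // A schematic example of what this analysis prepares (hypothetical IR,
  // shapes and value names are only illustrative, not taken from a test):
  // clang-format off
  // ```
  //   %res = scf.for %i = %c0 to %c4096 step %c16
  //       iter_args(%arg = %init) -> (tensor<4096xf32>) {
  //     %v = vector.transfer_read %src[%i], %pad : tensor<4096xf32>, vector<16xf32>
  //     %a = arith.addf %v, %v : vector<16xf32>
  //     %w = vector.transfer_write %a, %arg[%i] : vector<16xf32>, tensor<4096xf32>
  //     scf.yield %w : tensor<4096xf32>
  //   }
  // ```
  // clang-format on
  // Uses of the group's original result outside the loop are then replaced
  // with %res, and operands coming from other groups are re-read from the
  // tensors collected as the loop init args.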
if (enableDebugPrinter) { printGroupOps(getFusionStrategy().getOpGroups()); llvm::outs() << "___________ before analysis ________________" @@ -3496,14 +3505,8 @@ mlir::FailureOr getOperationOperateTensor(Operation *op) { [&](vector::TransferWriteOp transferWriteOp) { // find original tensor.empty operation auto writeTensor = transferWriteOp->getOperand(1); - while (auto wtOp = dyn_cast( - writeTensor.getDefiningOp())) { - if (transferWriteOp->getBlock() != - writeTensor.getDefiningOp()->getBlock()) { - break; - } - writeTensor = wtOp->getOperand(1); - } + writeTensor = + findOriginalTensor(writeTensor, transferWriteOp->getBlock()); return writeTensor; }) .Case([&](vector::TransferReadOp transferReadOp) { @@ -3619,13 +3622,14 @@ void CanonicalizerCommonUsedData::generateEmptyTensorAndWrite( auto [tsr, writeOpresult] = canonicalizeSourceOperation(sourceOp, visitedOperation); auto writeOp = writeOpresult.getDefiningOp(); + assert(writeOp); srcOpCanoniclizedMap.insert({sourceOp, {tsr, writeOpresult}}); updateOpOperandResultInGroups(sourceOpGid, sourceOp, tsr, writeOpresult); groupOpInitArgs[sourceOpGid].insert(tsr); groupOpResults[sourceOpGid].insert({writeOpresult, {retKind, anchorPos}}); // write opeartion anchor pos is same with current operation getFusionStrategy().getOpAnchorPos()[writeOp] = - cast(writeOp).getVectorType().getRank() - 1; + writeOp.getVectorType().getRank() - 1; getOpPermuationMap()[writeOp] = writeOp.getPermutationMap(); } @@ -3633,6 +3637,9 @@ template Operation *CanonicalizerCommonUsedData::getNextTargetOperationInCurrentGroup( Operation *curOp, const size_t grpIdx) { std::queue tmpOpQueue(getFusionStrategy().getOpGroups()[grpIdx]); + if (isa(curOp)) + return curOp; + while (!tmpOpQueue.empty()) { auto frontOp = tmpOpQueue.front(); if (isa(frontOp)) { @@ -3773,16 +3780,137 @@ void VectorOperationAnalyzer::updateReturnResultKind(Operation *sourceOp, } } +void VectorOperationAnalyzer::replaceConstantOpAsNewOp(Operation *op, + Operation *sourceOp, + size_t operandIdx) { + DenseMap &opGroupIndexMap = + getFusionStrategy().getOpGroupIndexMap(); + if (!opGroupIndexMap.contains(op)) { + return; + } + // TODO: add more operation to this case, write a constant value need + // to do this + if (isa(op) and operandIdx == 0) + return; + + if (isa(op)) { + if (operandIdx == 1) { + // accumulate value, just empty tensor is okay + auto resultTensor = getOperationResultTensor(sourceOp, visitedOperation); + auto opInit = canonicalizeCurrentOperation(op, resultTensor, operandIdx); + updateOpOperandResultInGroups(opGroupIndexMap[op], op, opInit); + return; + } + // source operation is the value + llvm::llvm_unreachable_internal( + "Need to add reduce constant operation optimization."); + } + + auto constantOp = cast(sourceOp); + IRRewriter rewriter(constantOp); + size_t groupSteps = + getFusionStrategy().getGroupMaxSteps()[opGroupIndexMap[op]]; + + if (isa(constantOp.getValue())) { + VectorType newOperandType = getVectorzedType(op, groupSteps); + auto valueType = cast(constantOp.getValue()); + if (valueType.isSplat()) { + FailureOr res = createArithSplatConstantOp( + rewriter, constantOp->getLoc(), valueType, newOperandType); + if (failed(res)) { + llvm::llvm_unreachable_internal("Wrong to create constant op."); + } + op->setOperand(operandIdx, res.value()); + // transfer read operation just use the constant value to do + // calculation, don't need to read. 
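+      // For instance (illustrative only): a splat
+      // `arith.constant dense<0.0> : tensor<16x16xf32>` that is consumed
+      // through a vector.transfer_read can be rewritten as
+      // `arith.constant dense<0.0> : vector<16x16xf32>` used directly,
+      // and the read removed from the group.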
+ if (isa(op) and operandIdx == 0) + removeOpInCurrentGroups(opGroupIndexMap[op], op, + op->getOperand(0).getDefiningOp()); + return; + } + llvm::llvm_unreachable_internal("Can't support not splat constant value."); + } +} + +void VectorOperationAnalyzer::makeSourceOpWriteResultToTensor( + Operation *sourceOp, size_t sourceOpGid, ReturnTypeKind rtKind) { + DenseMap &OpAnchorPos = + getFusionStrategy().getOpAnchorPos(); + SmallVector, 8> &groupOpInitArgs = getGroupOpInitArgs(); + + if (!srcOpCanoniclizedMap.contains(sourceOp)) { + // get write operation + if (Operation *writeOp = + getNextTargetOperationInCurrentGroup( + sourceOp, sourceOpGid)) { + auto writeOpresult = writeOp->getResults()[0]; + auto writeTensor = writeOp->getOperands()[1]; + // find original tensor.empty operation + writeTensor = findOriginalTensor(writeTensor, sourceOp->getBlock()); + srcOpCanoniclizedMap.insert({sourceOp, {writeTensor, writeOpresult}}); + groupOpInitArgs[sourceOpGid].insert(writeTensor); + updateReturnResultKind(writeOp, sourceOpGid, rtKind); + return; + } + generateEmptyTensorAndWrite(sourceOp, srcOpCanoniclizedMap, + OpAnchorPos[sourceOp], rtKind, + visitedOperation); + return; + } + // udpate result return type + // updateReturnResultKind(sourceOp, sourceOpGid, rtKind); + updateReturnResultKind(srcOpCanoniclizedMap[sourceOp].second.getDefiningOp(), + sourceOpGid, rtKind); +} + +void VectorOperationAnalyzer::groupOperationNeedReturnResult( + size_t sourceOpGid, Operation *sourceOp, Operation *op, size_t operandIdx, + bool inSameGroupNeedReturn) { + ReturnTypeKind rtKind = inSameGroupNeedReturn ? ReturnTypeKind::RT_InGroup + : ReturnTypeKind::RT_OutGroup; + SmallVector, 8> &groupOpInitArgs = getGroupOpInitArgs(); + + DenseMap &opGroupIndexMap = + getFusionStrategy().getOpGroupIndexMap(); + // update init iterargs + auto dstRet = getOperationOperateTensor(sourceOp); + // need to generate tensor.emtpy and vector.transfer_write, write + // operand to tensor and read operand from the tensor, generate + // vector.transfer_read + if (failed(dstRet)) { + // already generate result tensor, special operation do the + // transformation by itself + if (isSpecialOp(sourceOp) and inSameGroupNeedReturn) { + return; + } + makeSourceOpWriteResultToTensor(sourceOp, sourceOpGid, rtKind); + auto opInit = canonicalizeCurrentOperation( + op, srcOpCanoniclizedMap[sourceOp].second, operandIdx); + updateOpOperandResultInGroups(opGroupIndexMap[op], op, opInit); + return; + } + // if source operation is transfer_read, we need to generate a + // same transfer_read operation like source operation. 
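+  // (i.e. the consuming group re-reads the tensor at its own anchor via a
+  // cloned transfer_read, rather than threading the vector value itself
+  // across loop nests.)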
+ if (isa(sourceOp)) { + auto transferReadOp = cast(sourceOp); + auto opInit = canonicalizeCurrentOperation(op, dstRet.value(), operandIdx, + &transferReadOp); + updateOpOperandResultInGroups(opGroupIndexMap[op], op, opInit); + return; + } + // transfer write operation + groupOpInitArgs[sourceOpGid].insert(dstRet.value()); + updateReturnResultKind(sourceOp, sourceOpGid, rtKind); +} + void VectorOperationAnalyzer::analysisGroupOperaion() { // record the operation which has been moved DenseSet movedOperationSet; // record the operation's visited order, inorder to ensure set // correct operand size_t opCounter = 0; - DenseMap visitedOperation; DenseMap &opGroupIndexMap = getFusionStrategy().getOpGroupIndexMap(); - SmallVector, 8> &groupOpInitArgs = getGroupOpInitArgs(); DenseMap &OpAnchorPos = getFusionStrategy().getOpAnchorPos(); IRRewriter rewriter(func); @@ -3798,149 +3926,25 @@ void VectorOperationAnalyzer::analysisGroupOperaion() { auto sourceOpGid = opGroupIndexMap[sourceOp]; bool notInSameGroup = opGroupIndexMap.contains(op) && sourceOpGid != opGroupIndexMap[op]; - bool outOfGroup = !opGroupIndexMap.contains(op); // Different anchor in same group and source operation is in inner // loop, we need to get source operation's result bool inSameGroupNeedReturn = !notInSameGroup and OpAnchorPos[sourceOp] > OpAnchorPos[op]; - ReturnTypeKind rtKind = inSameGroupNeedReturn - ? ReturnTypeKind::RT_InGroup - : ReturnTypeKind::RT_OutGroup; if (notInSameGroup or outOfGroup or inSameGroupNeedReturn) { - // update init iterargs - auto dstRet = getOperationOperateTensor(sourceOp); - // need to generate tensor.emtpy and vector.transfer_write, write - // operand to tensor and read operand from the tensor, generate - // vector.transfer_read - if (failed(dstRet)) { - // already generate result tensor, special operation do the - // transformation by itself - if (isSpecialOp(sourceOp) and inSameGroupNeedReturn) { - continue; - } - if (!srcOpCanoniclizedMap.contains(sourceOp)) { - // get write operation - if (auto writeOp = getNextTargetOperationInCurrentGroup< - vector::TransferWriteOp>(sourceOp, sourceOpGid)) { - auto writeOpresult = writeOp->getResults()[0]; - auto writeTensor = writeOp->getOperands()[1]; - // find original tensor.empty operation - while (auto wtOp = dyn_cast( - writeTensor.getDefiningOp())) { - if (sourceOp->getBlock() != - writeTensor.getDefiningOp()->getBlock()) { - break; - } - writeTensor = wtOp->getOperand(1); - } - srcOpCanoniclizedMap.insert( - {sourceOp, {writeTensor, writeOpresult}}); - groupOpInitArgs[sourceOpGid].insert(writeTensor); - updateReturnResultKind(writeOp, sourceOpGid, rtKind); - } else { - generateEmptyTensorAndWrite(sourceOp, srcOpCanoniclizedMap, - OpAnchorPos[sourceOp], rtKind, - visitedOperation); - } - } else { - // udpate result return type - // updateReturnResultKind(sourceOp, sourceOpGid, rtKind); - updateReturnResultKind( - srcOpCanoniclizedMap[sourceOp].second.getDefiningOp(), - sourceOpGid, rtKind); - } - - auto opInit = canonicalizeCurrentOperation( - op, srcOpCanoniclizedMap[sourceOp].second, idx); - updateOpOperandResultInGroups(opGroupIndexMap[op], op, opInit); - - } else { - // if source operation is transfer_read, we need to generate a - // same transfer_read operation like source operation. 
- if (isa(sourceOp)) { - auto transferReadOp = cast(sourceOp); - auto opInit = canonicalizeCurrentOperation(op, dstRet.value(), - idx, &transferReadOp); - updateOpOperandResultInGroups(opGroupIndexMap[op], op, opInit); - - } else { - groupOpInitArgs[sourceOpGid].insert(dstRet.value()); - updateReturnResultKind(sourceOp, sourceOpGid, rtKind); - } - } - } - } else if (isa_and_nonnull(sourceOp)) { - if (!opGroupIndexMap.contains(op)) { - continue; - } - // TODO: add more operation to this case, write a constant value need - // to do this - if (isa(op)) { - if (idx == 0) - continue; - } - - if (isa(op)) { - if (idx == 1) { - // accumulate value, just empty tensor is okay - auto resultTensor = - getOperationResultTensor(sourceOp, visitedOperation); - auto opInit = canonicalizeCurrentOperation(op, resultTensor, idx); - updateOpOperandResultInGroups(opGroupIndexMap[op], op, opInit); - continue; - } else { - // source operation is the value - llvm::llvm_unreachable_internal( - "Need to add reduce constant operation optimization."); - } - } - - auto constantOp = cast(sourceOp); - IRRewriter rewriter(constantOp); - size_t groupSteps = - getFusionStrategy().getGroupMaxSteps()[opGroupIndexMap[op]]; - - if (isa(constantOp.getValue())) { - VectorType newOperandType = getVectorzedType(op, groupSteps); - auto valueType = cast(constantOp.getValue()); - if (valueType.isSplat()) { - FailureOr res = createArithSplatConstantOp( - rewriter, constantOp->getLoc(), valueType, newOperandType); - if (failed(res)) { - llvm::llvm_unreachable_internal("Wrong to create constant op."); - } - op->setOperand(idx, res.value()); - } else { - llvm::llvm_unreachable_internal( - "Can't support not splat constant value."); - } - - // transfer read operation just use the constant value to do - // calculation, don't need to read. - if (isa(op)) { - if (idx == 0) { - removeOpInCurrentGroups(opGroupIndexMap[op], op, - op->getOperand(0).getDefiningOp()); - continue; - } - } + groupOperationNeedReturnResult(sourceOpGid, sourceOp, op, idx, + inSameGroupNeedReturn); } + continue; + } + if (isa_and_nonnull(sourceOp)) { + replaceConstantOpAsNewOp(op, sourceOp, idx); } } - // if (mlir::isa(op) && !movedOperationSet.contains(op)) - // { - // auto parentBlock = op->getBlock(); - // std::stack opStack; - - // op->moveBefore(parentBlock, parentBlock->getOperations().begin()); - // movedOperationSet.insert(op); - // } }); analysisEmptyGroup(); specialOperationRectify(visitedOperation); -#undef RESULT_RETURN_TYPE LDBG("Complete analysis group operation results\n"); } @@ -4011,28 +4015,6 @@ mlir::FailureOr ForLoopGenerator::generateVectorizedForLoop( return forOp; } -void updateLoopResultUses( - llvm::MapVector> &opResults, - scf::ForOp *forOp) { - if (opResults.empty()) { - return; - } - IRRewriter rewriter(*forOp); - OpBuilder::InsertionGuard g(rewriter); - // Only different group operation operand need to be replaced due to same - // group operation should directly use original operand. 
- - Operation *producerOp = opResults.begin()->first.getDefiningOp(); - auto needToReplaced = [&](OpOperand &operand) { - return producerOp->getBlock() != operand.getOwner()->getBlock(); - }; - // update loop result uses - for (auto [retIdx, rt] : llvm::enumerate(opResults)) { - rewriter.replaceUsesWithIf(rt.first, forOp->getResult(retIdx), - needToReplaced); - } -} - bool CanonicalizerCommonUsedData::isGroupHasSpecialOperation( const size_t grpIdx) { auto &rdCanonicalizer = getMultiRdCanonicalizers()[grpIdx]; @@ -4171,7 +4153,7 @@ void moveCandidateOperation( // Because those operation may interpret our analysis result. e.g.: // ``` // clang-format off - // %21 = vector.transfer_read %18[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<16x16xf32>, vector<16x16xf32> + // %21 = vector.transfer_read %18[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<16x16xf32>, vector<16x16xf32> // %22 = arith.addf %21, %20 : vector<16x16xf32> // %23 = vector.transfer_write %22, %extracted_slice_12[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf32>, tensor<16x16xf32> // %inserted_slice_13 = tensor.insert_slice %18 into %arg14[%arg13, 0] [16, 16] [1, 1] : tensor<16x16xf32> into tensor<32x16xf32> @@ -4237,11 +4219,14 @@ struct CPUPhysicalRegisterPass // canonicalize vector operation, default use vector-based fusion // strategy. HardWareInfo hwInfo; - // default has avx512f instructions - // hwInfo.favx512f = false; + CPUTargetDescriptionAnalysis sysDesc = + getAnalysis(); + hwInfo.favx512f = sysDesc.getMaxVectorWidth() == 512; + hwInfo.favx2 = sysDesc.getMaxVectorWidth() >= 256; CanonicalizerVectorOperation canonicalizer( func, CanonicalizerKind::OperationsGroup, hwInfo); canonicalizer.run(); + candidateFunc = isReadOrWriteOperation; moveSomeInterferenceOperation(&func, ctx, candidateFunc); diff --git a/include/gc/Transforms/TilingVector.h b/lib/gc/Transforms/TilingVector.h similarity index 96% rename from include/gc/Transforms/TilingVector.h rename to lib/gc/Transforms/TilingVector.h index 4a6531ce8..4a6b8cde2 100644 --- a/include/gc/Transforms/TilingVector.h +++ b/lib/gc/Transforms/TilingVector.h @@ -1,4 +1,4 @@ -//===- TilingVector.h - Tiling large vector to small vector ---*- C++ -*-===// +//===- TilingVector.h - Tiling large vector to small vector -----*- C++ -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
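// --- Illustrative aside (not part of this patch) ---------------------------
// A minimal sketch of how the register width queried from the target
// description above maps to a per-type loop step, in the spirit of
// TypeHelper::generateValidSteps used by this pass; the helper name is
// hypothetical. It assumes the element bit width divides the register width;
// any other target is treated as invalid, matching the pass.
static int vectorLanesFor(const HardWareInfo &info, unsigned elementBits) {
  if (info.favx512f)
    return 512 / elementBits; // e.g. f32 -> 16 lanes, as in the tests below
  if (info.favx2)
    return 256 / elementBits; // e.g. f32 -> 8 lanes
  assert(false && "Invalid hardware.");
  return -1;
}
// ---------------------------------------------------------------------------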
@@ -8,6 +8,7 @@ #ifndef GC_PASSES_TILINGVECTOR_H #define GC_PASSES_TILINGVECTOR_H +#include "gc/Analysis/TargetDescriptionAnalysis.h" #include "gc/Dialect/Linalgx/LinalgxOps.h" #include "gc/Transforms/Passes.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" @@ -634,6 +635,7 @@ class VectorOperationAnalyzer : virtual public CanonicalizerCommonUsedData { private: func::FuncOp func; DenseMap> srcOpCanoniclizedMap; + DenseMap visitedOperation; public: virtual ~VectorOperationAnalyzer(){}; @@ -653,6 +655,18 @@ class VectorOperationAnalyzer : virtual public CanonicalizerCommonUsedData { /// void updateReturnResultKind(Operation *sourceOp, size_t sourceOpGid, ReturnTypeKind rtKind); + + /// process the operation which need to return result + /// \param *op current operation + void groupOperationNeedReturnResult(size_t sourceOpGid, Operation *sourceOp, + Operation *op, size_t operandIdx, + bool inSameGroupNeedReturn); + /// source operation write it's result to a tensor + void makeSourceOpWriteResultToTensor(Operation *sourceOp, size_t sourceOpGid, + ReturnTypeKind rtKind); + /// analysis constant operation and replace it with a new constant operation + void replaceConstantOpAsNewOp(Operation *op, Operation *sourceOp, + size_t operandIdx); }; /// Vectorize vector operation with target machines simd instructions. class CanonicalizerVectorOperation : virtual public ForLoopGenerator, From 545cf9850fb93a81803ccd36e330dc78dc3c74c1 Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Thu, 29 Aug 2024 21:04:22 +0800 Subject: [PATCH 31/66] fix clang-format --- lib/gc/Transforms/CPUPhysicalRegisterPass.cpp | 2 +- lib/gc/Transforms/TilingVector.h | 22 ++++++++++--------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp index 82fcb3299..7c39d5e6a 100644 --- a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp +++ b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp @@ -1,4 +1,4 @@ -//===- CPUPhysicalResigterPass.cpp - tiling as physical vector ---*-C++-*-===// +//===- CPUPhysicalRegisterPass.cpp - tiling as physical vector --*- C++ -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
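// --- Illustrative aside (not part of this patch) ---------------------------
// A minimal sketch of how the visitedOperation member added to
// VectorOperationAnalyzer above can be populated, assuming the surrounding
// pass's headers, the mlir namespace, and that the map is keyed by Operation*
// with a size_t visit index; the helper name is hypothetical. The analyzer
// walks the function once and stamps each operation with its visit order, so
// later comparisons of definition order stay cheap.
static void recordVisitOrder(func::FuncOp func,
                             DenseMap<Operation *, size_t> &visitedOperation) {
  size_t opCounter = 0;
  func->walk(
      [&](Operation *op) { visitedOperation.try_emplace(op, opCounter++); });
}
// ---------------------------------------------------------------------------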
diff --git a/lib/gc/Transforms/TilingVector.h b/lib/gc/Transforms/TilingVector.h index 4a6b8cde2..120a48f4f 100644 --- a/lib/gc/Transforms/TilingVector.h +++ b/lib/gc/Transforms/TilingVector.h @@ -112,16 +112,18 @@ class VectorFusionStrategy : public TypeHelper { VectorFusionStrategy(func::FuncOp &func) : func(func) {} VectorFusionStrategy(func::FuncOp &func, TypeHelper &typeHelper) : TypeHelper(typeHelper), func(func) {} + VectorFusionStrategy(VectorFusionStrategy &strategy) : func(strategy.func), opGroups(strategy.opGroups), groupMaxSteps(strategy.groupMaxSteps), opGroupIndexMap(strategy.opGroupIndexMap), - opAnchorPos(strategy.opAnchorPos){}; + opAnchorPos(strategy.opAnchorPos) {}; + VectorFusionStrategy(VectorFusionStrategy &&strategy) : func(std::move(strategy.func)), opGroups(std::move(strategy.opGroups)), groupMaxSteps(std::move(strategy.groupMaxSteps)), opGroupIndexMap(std::move(strategy.opGroupIndexMap)), - opAnchorPos(std::move(strategy.opAnchorPos)){}; + opAnchorPos(std::move(strategy.opAnchorPos)) {}; VectorFusionStrategy &operator=(VectorFusionStrategy &&) = default; @@ -217,7 +219,7 @@ class MultiReductionCanonicalizer isStandaloneOp = candidateRdOps.size() == 1; prepareSpecialOperationInfo(); }; - virtual ~MultiReductionCanonicalizer(){}; + virtual ~MultiReductionCanonicalizer() {}; int64_t getTypeRank(); void getReductionAxisAndParallelAxis(); bool hasLastDimReduction(); @@ -255,7 +257,7 @@ class BroadcastCanonicalizer BroadcastCanonicalizer( const llvm::SmallVector &candidateBcOps) : SpecialOperationCanonicalizer( - candidateBcOps, SpecialOperationKind::OP_Broadcast){}; + candidateBcOps, SpecialOperationKind::OP_Broadcast) {}; virtual ~BroadcastCanonicalizer() {} void prepareSpecialOperationInfo() override {} static bool classof(SpecialOperationCanonicalizer *canonicalizer) { @@ -272,7 +274,7 @@ class TransposeCanonicalizer TransposeCanonicalizer( const llvm::SmallVector &candidateTpOps) : SpecialOperationCanonicalizer( - candidateTpOps, SpecialOperationKind::OP_Transpose){}; + candidateTpOps, SpecialOperationKind::OP_Transpose) {}; virtual ~TransposeCanonicalizer() {} void prepareSpecialOperationInfo() override; static bool classof(SpecialOperationCanonicalizer *canonicalizer) { @@ -295,7 +297,7 @@ class ShapeCastCanonicalizer ShapeCastCanonicalizer( const llvm::SmallVector &candidateScOps) : SpecialOperationCanonicalizer( - candidateScOps, SpecialOperationKind::OP_ShapeCast){}; + candidateScOps, SpecialOperationKind::OP_ShapeCast) {}; virtual ~ShapeCastCanonicalizer() {} void prepareSpecialOperationInfo() override {} static bool classof(SpecialOperationCanonicalizer *canonicalizer) { @@ -331,7 +333,7 @@ class CanonicalizerCommonUsedData : public TypeHelper { public: CanonicalizerCommonUsedData() = default; CanonicalizerCommonUsedData(VectorFusionStrategy &fusionStrategy) - : fusionStrategy(fusionStrategy){}; + : fusionStrategy(fusionStrategy) {}; CanonicalizerCommonUsedData( VectorFusionStrategy &fusionStrategy, @@ -342,7 +344,7 @@ class CanonicalizerCommonUsedData : public TypeHelper { llvm::DenseMap &opPermuationMap) : fusionStrategy(fusionStrategy), groupOpResults(groupOpResults), groupOpInitArgs(groupOpInitArgs), opPermuationMap(opPermuationMap) {} - virtual ~CanonicalizerCommonUsedData(){}; + virtual ~CanonicalizerCommonUsedData() {}; /// Set fusion strategy void setFuseStrategy(VectorFusionStrategy &&strategy) { @@ -638,7 +640,7 @@ class VectorOperationAnalyzer : virtual public CanonicalizerCommonUsedData { DenseMap visitedOperation; public: - virtual 
~VectorOperationAnalyzer(){}; + virtual ~VectorOperationAnalyzer() {}; VectorOperationAnalyzer() {} VectorOperationAnalyzer(func::FuncOp &func) : func(func) {} @@ -693,7 +695,7 @@ class CanonicalizerVectorOperation : virtual public ForLoopGenerator, setFuseStrategy(std::move(fusionStrategy)); } } - virtual ~CanonicalizerVectorOperation(){}; + virtual ~CanonicalizerVectorOperation() = default; // get functions func::FuncOp &getFunc() { return func; }; From 87e69e1e5de5f4a875d101fb7bb6aa15e11ad367 Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Fri, 30 Aug 2024 10:42:33 +0800 Subject: [PATCH 32/66] simplify multireduction generate loop code --- lib/gc/Transforms/CPUPhysicalRegisterPass.cpp | 230 ++++++++++-------- lib/gc/Transforms/TilingVector.h | 29 ++- 2 files changed, 156 insertions(+), 103 deletions(-) diff --git a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp index 7c39d5e6a..38cc60f95 100644 --- a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp +++ b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp @@ -23,14 +23,15 @@ namespace { arith::FPToSIOp, arith::FPToUIOp, arith::SIToFPOp, arith::UIToFPOp, \ arith::TruncFOp, arith::TruncIOp -#define IMPLEMENTED_MATMUL \ +#define NOT_NEED_TO_PROCESS_OP \ linalgx::BatchReduceMatmulVnniOp, linalgx::MultiBatchMatmulOp, \ linalg::BatchReduceMatmulOp, linalgx::Mm2DVnniOp, linalgx::Mm4DVnniOp, \ linalg::MatmulOp, linalg::BatchMatmulOp, \ linalg::BatchMatmulTransposeAOp, linalg::BatchMatmulTransposeBOp, \ linalg::MatmulTransposeAOp, linalg::MatmulTransposeBOp, \ linalg::QuantizedBatchMatmulOp, linalg::QuantizedMatmulOp, \ - tensor::CollapseShapeOp, tensor::ExpandShapeOp + tensor::CollapseShapeOp, tensor::ExpandShapeOp, tensor::ExtractSliceOp, \ + tensor::InsertSliceOp /// TODO: remove it in the future bool disableSpecialOp = false; @@ -65,11 +66,12 @@ static inline bool isUsedByOtherOp(Operation *op) { } static inline bool isCandidateMoveOperations(Operation *op) { - return isa(op); + return isa( + op); } -static inline bool isSpecialLinalgOp(Operation *op) { - return isa(op); +static inline bool isNotNeedToProcessOp(Operation *op) { + return isa(op); } static inline bool isReadOrWriteOperation(Operation *op) { @@ -1812,8 +1814,24 @@ scf::ForOp ForLoopGenerator::generateTransposeForLoopWithLastDim( }); } -scf::ForOp -ForLoopGenerator::generateMultiReductionForLoop(const size_t grpIdx) { +ValueRange ForLoopGenerator::prepareForLoopArgs( + const size_t grpIdx, DenseMap ¤tLoopStateIdxMap, + DenseMap &originalOperandLoopArgsMap, + DenseMap &loopArgsOriginalOperandMap) { + SetVector &grpArgs = getGroupOpInitArgs()[grpIdx]; + SmallVector forLoopArgs(grpArgs.begin(), grpArgs.end()); + ValueRange initArgs(forLoopArgs); + for (auto [idx, val] : llvm::enumerate(initArgs)) { + currentLoopStateIdxMap[val] = idx; + originalOperandLoopArgsMap[val] = val; + loopArgsOriginalOperandMap[val] = val; + } + return initArgs; +} + +void ForLoopGenerator::rearrageMultiReductionIR( + const size_t grpIdx, + DenseMap> &indiceLoopMap) { MultiReductionCanonicalizer &rdCanonicalizer = getMultiRdCanonicalizers()[grpIdx]; vector::MultiDimReductionOp multiReductionOp = @@ -1825,7 +1843,6 @@ ForLoopGenerator::generateMultiReductionForLoop(const size_t grpIdx) { std::queue &accRelatedOps = rdCanonicalizer.getAccRelatedOps(); std::queue &sourceRelatedOps = rdCanonicalizer.getSourceRelatedOps(); - std::queue &opQueue = getFusionStrategy().getOpGroups()[grpIdx]; auto copyOpQueue(opQueue); getPrevOps(prevOps, copyOpQueue, multiReductionOp); @@ 
-1835,7 +1852,6 @@ ForLoopGenerator::generateMultiReductionForLoop(const size_t grpIdx) { // mark source read operation need to set correct for loop var idx std::queue tmpSourceQ(sourceRelatedOps); - DenseMap> indiceLoopMap; DenseMap varLoopIdxMap; VectorType groupVector = getFusionStrategy().getGroupBiggestRankVectorType()[grpIdx]; @@ -1861,9 +1877,9 @@ ForLoopGenerator::generateMultiReductionForLoop(const size_t grpIdx) { while (!from.empty()) { auto cur = from.front(); from.pop(); - if (pushedSet.contains(cur)) { + if (pushedSet.contains(cur)) continue; - } + to.push(cur); pushedSet.insert(cur); } @@ -1871,23 +1887,50 @@ ForLoopGenerator::generateMultiReductionForLoop(const size_t grpIdx) { moveOperation(accRelatedOps, rectifyQueue); moveOperation(opQueue, rectifyQueue); opQueue = rectifyQueue; +} + +void ForLoopGenerator::replaceOpUsersWithForLoopResult( + scf::ForOp forOp, int grpIdx, SmallVector &nextAnchorResults, + DenseMap &nextAnchorResultsIdxMap, + DenseMap &forResultOrignalResultMap) { + IRRewriter rewriter(func); + DenseSet forOpChildOps; + forOp->walk([&](Operation *op) { forOpChildOps.insert(op); }); + auto replaceIfFn = [&](OpOperand &use) { + return not forOpChildOps.contains(use.getOwner()); + }; + for (auto x : nextAnchorResults) { + auto originalResult = forResultOrignalResultMap[x]; + Value forResult = forOp->getResults()[nextAnchorResultsIdxMap[x]]; + rewriter.replaceOpUsesWithIf(originalResult.getDefiningOp(), forResult, + replaceIfFn); + // subsequent group must use the replaced result as operand + rectifyGroupOperands(grpIdx, originalResult, forResult); + } +} + +scf::ForOp +ForLoopGenerator::generateMultiReductionForLoop(const size_t grpIdx) { + MultiReductionCanonicalizer &rdCanonicalizer = + getMultiRdCanonicalizers()[grpIdx]; + vector::MultiDimReductionOp multiReductionOp = + rdCanonicalizer.getCandidateOps()[0]; + + DenseMap> indiceLoopMap; + rearrageMultiReductionIR(grpIdx, indiceLoopMap); // get current loop init args - SetVector &grpArgs = getGroupOpInitArgs()[grpIdx]; - SmallVector forLoopArgs(grpArgs.begin(), grpArgs.end()); - ValueRange initArgs(forLoopArgs); DenseMap currentLoopStateIdxMap; DenseMap nextAnchorResultsIdxMap; // map original operation operand with loop args DenseMap originalOperandLoopArgsMap, loopArgsOriginalOperandMap, forResultOrignalResultMap; - for (auto [idx, val] : llvm::enumerate(initArgs)) { - currentLoopStateIdxMap[val] = idx; - originalOperandLoopArgsMap[val] = val; - } + + ValueRange initArgs = prepareForLoopArgs(grpIdx, currentLoopStateIdxMap, + originalOperandLoopArgsMap, + loopArgsOriginalOperandMap); SmallVector inductionVars; - IRRewriter rewriter(func); OpBuilder opBuilder(rdCanonicalizer.getCandidateOps()[0]); SmallVector nextAnchorResults; @@ -1895,20 +1938,11 @@ ForLoopGenerator::generateMultiReductionForLoop(const size_t grpIdx) { opBuilder, grpIdx, 0, currentLoopStateIdxMap, initArgs, nextAnchorResults, nextAnchorResultsIdxMap, inductionVars, originalOperandLoopArgsMap, loopArgsOriginalOperandMap, forResultOrignalResultMap, indiceLoopMap); - DenseSet forOpChildOps; - forOp->walk([&](Operation *op) { forOpChildOps.insert(op); }); - auto replaceIfFn = [&](OpOperand &use) { - return not forOpChildOps.contains(use.getOwner()); - }; - for (auto x : nextAnchorResults) { - auto originalResult = forResultOrignalResultMap[x]; - rewriter.replaceOpUsesWithIf( - originalResult.getDefiningOp(), - forOp->getResults()[nextAnchorResultsIdxMap[x]], replaceIfFn); - // following group must use the replaced result as operand - 
rectifyGroupOperands(grpIdx, originalResult, - forOp->getResults()[nextAnchorResultsIdxMap[x]]); - } + replaceOpUsersWithForLoopResult(forOp, grpIdx, nextAnchorResults, + nextAnchorResultsIdxMap, + forResultOrignalResultMap); + + IRRewriter rewriter(func); rewriter.eraseOp(multiReductionOp); return forOp; @@ -2559,26 +2593,31 @@ void CanonicalizerVectorOperation::initSpeicalOperationCanonicalizers() { while (!tempQ.empty()) { auto op = tempQ.front(); tempQ.pop(); - if (isa(op)) { - getMultiRdCanonicalizers().back().getCandidateOps().emplace_back( - cast(op)); - getMultiRdCanonicalizers().back().prepareSpecialOperationInfo(); - } else if (isa(op)) { - getBroadcastCanonicalizers().back().getCandidateOps().emplace_back( - cast(op)); - } else if (isa(op)) { - getTransposeCanonicalizers().back().getCandidateOps().emplace_back( - cast(op)); - } else if (isa(op)) { - getShapeCastCanonicalizers().back().getCandidateOps().emplace_back( - cast(op)); - } + TypeSwitch(op) + .Case([&](vector::MultiDimReductionOp + multiReductionOp) { + getMultiRdCanonicalizers().back().getCandidateOps().emplace_back( + cast(op)); + getMultiRdCanonicalizers().back().prepareSpecialOperationInfo(); + }) + .Case([&](vector::BroadcastOp broadCastOp) { + getBroadcastCanonicalizers().back().getCandidateOps().emplace_back( + cast(op)); + }) + .Case([&](vector::TransposeOp tpOp) { + getTransposeCanonicalizers().back().getCandidateOps().emplace_back( + cast(op)); + }) + .Case([&](vector::ShapeCastOp spOp) { + getShapeCastCanonicalizers().back().getCandidateOps().emplace_back( + cast(op)); + }) + .Default([&](Operation *op) {}); } } } void CanonicalizerVectorOperation::canonicalizeSpecialOperation() { - // multireduction operation OpBuilder::InsertionGuard guard(rewriter); initSpeicalOperationCanonicalizers(); @@ -3015,11 +3054,6 @@ void getOperationDataAxis(Operation *op, SmallVector &dataAxis) { return TypeSwitch(op) .Case( [&](vector::MultiDimReductionOp multiReductionOp) { - // auto rdDimsRange = multiReductionOp.getReductionDims() - // .getAsValueRange(); - // auto reductionDims = llvm::to_vector<4>(map_range( - // rdDimsRange, [](const APInt &a) { return a.getZExtValue(); - // })); auto rdDimsRange = multiReductionOp.getReductionDims(); dataAxis.assign(rdDimsRange.begin(), rdDimsRange.end()); }) @@ -3303,12 +3337,12 @@ void VectorFusionStrategy::classifyOperations() { func->walk([&](Operation *op) { if (filterOperation(op)) { addOperationToGroup(op); - } else if (isSpecialLinalgOp(op)) { - // following operation need a new group - if (!opGroups.back().empty()) { - opGroups.emplace_back(std::queue()); - } + return WalkResult::advance(); + } + if (isNotNeedToProcessOp(op) and !opGroups.back().empty()) { + opGroups.emplace_back(std::queue()); } + return WalkResult::advance(); }); } @@ -3323,7 +3357,7 @@ Value setOutGroupOperationOperandResult(Operation *op, auto value = constantOp.getValue(); Attribute initValueAttr; - if (mlir::isa(value)) { + if (isa(value)) { auto valueType = mlir::dyn_cast(value); if (valueType.isSplat()) { if (mlir::isa(valueType.getElementType())) { @@ -3660,16 +3694,13 @@ void VectorOperationAnalyzer::analysisEmptyGroup() { SmallVector>, 8> &groupOpResults = getGroupOpResults(); for (auto [idx, grp] : llvm::enumerate(opGroups)) { - if (grp.empty()) { + if (grp.empty()) continue; - } - if (groupOpResults[idx].empty()) { + if (groupOpResults[idx].empty()) std::queue().swap(grp); - } } } -/// get each operation in each group maximum support vectorization length void 
VectorOperationAnalyzer::analysisGroupMaxSteps() { auto &opGroups = getFusionStrategy().getOpGroups(); @@ -3691,9 +3722,9 @@ void VectorOperationAnalyzer::analysisGroupMaxSteps() { while (!tmpQueue.empty()) { auto op = tmpQueue.front(); tmpQueue.pop(); - if (isa(op)) { + if (isa(op)) calculateOpSteps(op->getOperandTypes()[0]); - } + calculateOpSteps(getOperationVectorType(op).value()); } grpSteps[idx] = steps; @@ -3715,38 +3746,38 @@ void VectorOperationAnalyzer::specialOperationRectify( if (isa(op) and !disableBroadcastOp) { auto srcOp = op->getOperand(0).getDefiningOp(); assert(isa(srcOp)); - // just remain write operation, it's size will - // bigger than 1 if not write operation. Because the last operation - // always be write in each group - if (tmpQueue.size() <= 1) { + // only have write operation, otherwise the group size will bigger + // than 1. Because the last operation is always a write operation in + // each group + if (tmpQueue.size() <= 1) continue; - } + getFusionStrategy().getOpAnchorPos()[srcOp] = getFusionStrategy().getOpAnchorPos()[op]; rewriter.replaceOp(op, srcOp); continue; } - // anchor of multidim reduciton rectify + // anchor of multidim reduction rectify if (isa(op)) { auto accSourceOp = op->getOperand(1).getDefiningOp(); getFusionStrategy().getOpAnchorPos()[accSourceOp] = getOperationVectorType(accSourceOp)->getRank() - 1; } - // case: - // %1 = some op - // %2 = tensor.empty() - // %3 = vector.transfer_write %1, %2 - // -> move emtpy operation before %1 for better generate %1 - if (isa(op)) { - auto srcOp = op->getOperand(1).getDefiningOp(); - if (isa_and_nonnull(srcOp)) { - Operation *writeVectorOp = op->getOperands()[0].getDefiningOp(); - if (visitedOperation[srcOp] >= visitedOperation[writeVectorOp]) { - srcOp->moveBefore(writeVectorOp); - visitedOperation[srcOp] = visitedOperation[writeVectorOp]; - } - } - } + // // case: + // // %1 = some op + // // %2 = tensor.empty() + // // %3 = vector.transfer_write %1, %2 + // // -> move emtpy operation before %1 for better generate %1 + // if (isa(op)) { + // auto srcOp = op->getOperand(1).getDefiningOp(); + // if (isa_and_nonnull(srcOp)) { + // Operation *writeVectorOp = op->getOperands()[0].getDefiningOp(); + // if (visitedOperation[srcOp] >= visitedOperation[writeVectorOp]) { + // srcOp->moveBefore(writeVectorOp); + // visitedOperation[srcOp] = visitedOperation[writeVectorOp]; + // } + // } + // } newQueue.push(op); } getFusionStrategy().getOpGroups()[idx] = newQueue; @@ -3858,7 +3889,6 @@ void VectorOperationAnalyzer::makeSourceOpWriteResultToTensor( return; } // udpate result return type - // updateReturnResultKind(sourceOp, sourceOpGid, rtKind); updateReturnResultKind(srcOpCanoniclizedMap[sourceOp].second.getDefiningOp(), sourceOpGid, rtKind); } @@ -3954,20 +3984,19 @@ void ForLoopGenerator::rectifyGroupOperands(size_t currentGroupId, size_t totalGroupSize = getFusionStrategy().getOpGroups().size(); size_t startGroup = currentGroupId; while (startGroup < totalGroupSize) { - SetVector &operandVector = getGroupOpInitArgs()[startGroup]; - if (operandVector.contains(originalResult)) { - SetVector replacedVector; + SetVector &operandVector = getGroupOpInitArgs()[startGroup++]; + if (not operandVector.contains(originalResult)) + continue; + SetVector replacedVector; - for (auto v : operandVector) { - if (v == originalResult) { - replacedVector.insert(forResult); - } else { - replacedVector.insert(v); - } + for (auto v : operandVector) { + if (v == originalResult) { + replacedVector.insert(forResult); + continue; } - 
getGroupOpInitArgs()[startGroup] = replacedVector; + replacedVector.insert(v); } - startGroup++; + getGroupOpInitArgs()[startGroup - 1] = replacedVector; } } @@ -4131,6 +4160,9 @@ void moveCandidateOperation( [&](tensor::ExtractSliceOp extractOp) { return moveFront(op, operationPosition); }) + .Case([&](tensor::EmptyOp emptyOp) { + return moveFront(op, operationPosition); + }) .Case([&](tensor::InsertSliceOp insertOp) { return moveBack(op, operationPosition); }) diff --git a/lib/gc/Transforms/TilingVector.h b/lib/gc/Transforms/TilingVector.h index 120a48f4f..e1bcaf292 100644 --- a/lib/gc/Transforms/TilingVector.h +++ b/lib/gc/Transforms/TilingVector.h @@ -306,6 +306,8 @@ class ShapeCastCanonicalizer bool isReadWriteOnLastDim(); }; +/// operation return kind, which is used to determine whether the operation need +/// to return it's result in current for loop enum class ReturnTypeKind { RT_Both, RT_OutGroup, @@ -440,7 +442,7 @@ class CanonicalizerCommonUsedData : public TypeHelper { Operation *getNextTargetOperationInCurrentGroup(Operation *curOp, const size_t grpIdx); }; - +/// generate for loop for each operation. class ForLoopGenerator : virtual public CanonicalizerCommonUsedData { private: func::FuncOp func; @@ -454,6 +456,19 @@ class ForLoopGenerator : virtual public CanonicalizerCommonUsedData { void clearCurrentOperationGroup(size_t grpIdx); void generateGroupOpVectorizedIR(const int idx); + /// prepare for loop iteration args + ValueRange + prepareForLoopArgs(const size_t grpIdx, + DenseMap ¤tLoopStateIdxMap, + DenseMap &originalOperandLoopArgsMap, + DenseMap &loopArgsOriginalOperandMap); + + /// replace original operation result with corresponding for loop result + void replaceOpUsersWithForLoopResult( + scf::ForOp forOp, int grpIdx, SmallVector &nextAnchorResults, + DenseMap &nextAnchorResultsIdxMap, + DenseMap &forResultOrignalResultMap); + /// mark which operation need to set correct for loop var idx /// due to sometimes we need to chage for loop order like reduce operation. 
void getCurrentGroupIndiceLoopMap( @@ -563,6 +578,11 @@ class ForLoopGenerator : virtual public CanonicalizerCommonUsedData { DenseMap &forResultOrignalResultMap); // multireduction forloop methods scf::ForOp generateMultiReductionForLoop(const size_t grpIdx); + /// Rearrange the current opIR to facilitate the generation of the correct + /// reduction IR + void rearrageMultiReductionIR( + const size_t grpIdx, + DenseMap> &indiceLoopMap); scf::ForOp reductionAxisGenerateForLoop( OpBuilder &opBuilder, const int groupIdx, const size_t reductionIdx, const int anchorIdx, llvm::DenseMap ¤tLoopStateIdxMap, @@ -640,21 +660,22 @@ class VectorOperationAnalyzer : virtual public CanonicalizerCommonUsedData { DenseMap visitedOperation; public: - virtual ~VectorOperationAnalyzer() {}; - VectorOperationAnalyzer() {} + virtual ~VectorOperationAnalyzer() = default; + VectorOperationAnalyzer() = default; VectorOperationAnalyzer(func::FuncOp &func) : func(func) {} void setAnalysisFunc(func::FuncOp &func) { this->func = func; } /// remove the useless operation, due to it result is not require by other // operation void analysisEmptyGroup(); + /// get each operation in each group maximum support vectorization length void analysisGroupMaxSteps(); /// analysis operation result of current group whether needed by other /// operation void analysisGroupOperaion(); void specialOperationRectify(DenseMap &visitedOperation); - /// + /// update operation result kind void updateReturnResultKind(Operation *sourceOp, size_t sourceOpGid, ReturnTypeKind rtKind); From dc00e70f66bbaf428e0b248d815694f650da700d Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Fri, 30 Aug 2024 14:35:54 +0800 Subject: [PATCH 33/66] update test --- test/gc/transforms/linalg-vectorization.mlir | 99 ------------ .../Transforms}/cpu-vetor-distribution.mlir | 142 +++++++++++------- 2 files changed, 89 insertions(+), 152 deletions(-) delete mode 100644 test/gc/transforms/linalg-vectorization.mlir rename test/{gc/transforms => mlir/test/gc/Transforms}/cpu-vetor-distribution.mlir (75%) diff --git a/test/gc/transforms/linalg-vectorization.mlir b/test/gc/transforms/linalg-vectorization.mlir deleted file mode 100644 index a3b68a92d..000000000 --- a/test/gc/transforms/linalg-vectorization.mlir +++ /dev/null @@ -1,99 +0,0 @@ -// RUN: gc-opt --split-input-file -pass-pipeline='builtin.module(func.func(lower-to-tile-vector))' --mlir-print-ir-after-all -- %s - -// CHECK-LABEL: func @add_tensor -func.func @add_tensor_test0(%arg0: tensor<4x8x16xf32>, %arg1: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> { - %0 = tensor.empty() : tensor<4x8x16xf32> - %1 = linalg.add ins(%arg0, %arg1 : tensor<4x8x16xf32>, tensor<4x8x16xf32>) outs(%0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> - return %1 : tensor<4x8x16xf32> -} - -func.func @add_tensor_test1(%arg0: tensor<4x8x16xf32>, %arg1: tensor<4x8x16xf32>) -> tensor<1x8x8xf32> { - %0 = tensor.empty() : tensor<1x8x8xf32> - %1 = tensor.extract_slice %arg0[0, 0, 0] [1, 8, 8] [1, 1, 1] : tensor<4x8x16xf32> to tensor<1x8x8xf32> - %2 = tensor.extract_slice %arg1[0, 0, 0] [1, 8, 8] [1, 1, 1] : tensor<4x8x16xf32> to tensor<1x8x8xf32> - %3 = linalg.add ins(%1, %2 : tensor<1x8x8xf32>, tensor<1x8x8xf32>) outs(%0: tensor<1x8x8xf32>) -> tensor<1x8x8xf32> - return %3 : tensor<1x8x8xf32> -} - -func.func @add_tensor_pack_test2(%arg0: tensor<4x16x16xf32>, %arg1: tensor<4x16x16xf32>) -> tensor<4x4x4x4x4xf32> { - %0 = tensor.empty() : tensor<4x4x4x4x4xf32> - %1 = tensor.empty() : tensor<4x4x4x4x4xf32> - %2 = tensor.pack %arg0 outer_dims_perm = [1, 0, 
2] inner_dims_pos = [1, 2] inner_tiles = [4, 4] into %0 : tensor<4x16x16xf32> -> tensor<4x4x4x4x4xf32> - %3 = tensor.pack %arg1 outer_dims_perm = [1, 0, 2] inner_dims_pos = [1, 2] inner_tiles = [4, 4] into %1 : tensor<4x16x16xf32> -> tensor<4x4x4x4x4xf32> - %4 = tensor.empty() : tensor<4x4x4x4x4xf32> - %6 = linalg.add ins(%2, %3 : tensor<4x4x4x4x4xf32>, tensor<4x4x4x4x4xf32>) outs(%4: tensor<4x4x4x4x4xf32>) -> tensor<4x4x4x4x4xf32> - return %6 : tensor<4x4x4x4x4xf32> -} - -func.func @add_tensor_pad_test3(%arg0: tensor<4x16x15xf32>, %arg1: tensor<4x16x15xf32>) -> tensor<4x16x16xf32> { - %cst = arith.constant 0.000000e+00 : f32 - %0 = tensor.pad %arg0 low[0, 0, 0] high[0, 0, 1] { - ^bb0(%arg2: index, %arg3: index, %arg4: index): - tensor.yield %cst : f32 - } : tensor<4x16x15xf32> to tensor<4x16x16xf32> - %1 = tensor.pad %arg1 low[0, 0, 0] high[0, 0, 1] { - ^bb0(%arg5: index, %arg6: index, %arg7: index): - tensor.yield %cst : f32 - } : tensor<4x16x15xf32> to tensor<4x16x16xf32> - %2 = tensor.empty() : tensor<4x16x16xf32> - %3 = linalg.add ins(%0, %1 : tensor<4x16x16xf32>, tensor<4x16x16xf32>) outs(%2: tensor<4x16x16xf32>) -> tensor<4x16x16xf32> - return %3 : tensor<4x16x16xf32> -} - -func.func @add_tensor_test4(%arg0: tensor<12x2x56x56x32xf32>, %arg1: tensor<12x2x56x56x32xf32>) -> tensor<12x56x56x64xf32> { - %0 = tensor.empty() : tensor<12x56x56x64xf32> - %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %0 : tensor<12x2x56x56x32xf32> -> tensor<12x56x56x64xf32> - %2 = tensor.empty() : tensor<12x56x56x64xf32> - %3 = tensor.unpack %arg1 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %2 : tensor<12x2x56x56x32xf32> -> tensor<12x56x56x64xf32> - %4 = tensor.empty() : tensor<12x56x56x64xf32> - %5 = linalg.add ins(%1, %3 : tensor<12x56x56x64xf32>, tensor<12x56x56x64xf32>) outs(%4: tensor<12x56x56x64xf32>) -> tensor<12x56x56x64xf32> - return %5 : tensor<12x56x56x64xf32> -} - -func.func @add_tensor_test5() -> tensor<1x1x1x8xf32> { - %cst = arith.constant 1.000000e+00 : f32 - %init = tensor.empty() : tensor<1x8xf32> - %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<1x8xf32>) -> tensor<1x8xf32> - %slice = tensor.extract_slice %fill[0, 0] [1, 8] [1, 1] : tensor<1x8xf32> to tensor<8xf32> - %expand = tensor.expand_shape %slice [[0, 1, 2, 3]] : tensor<8xf32> into tensor<1x1x1x8xf32> - return %expand : tensor<1x1x1x8xf32> -} - -func.func @tensor_collapse_shape_test0(%arg0: tensor<2x3xf32>) -> tensor<6xf32> { - %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<2x3xf32> into tensor<6xf32> - return %0 : tensor<6xf32> -} - -func.func @tensor_bitcast_test0(%input: tensor<2xi32>) -> tensor<2xf32> { - %0 = tensor.bitcast %input : tensor<2xi32> to tensor<2xui32> - %1 = tensor.bitcast %0 : tensor<2xui32> to tensor<2xf32> - return %1 : tensor<2xf32> -} - -func.func @tensor_static_concat_test0(%arg0 : tensor<1x1x64xf32>, - %arg1: tensor<1x1x64xf32>) -> tensor<1x1x128xf32> { - %0 = tensor.concat dim(2) %arg0, %arg1 - : (tensor<1x1x64xf32>, tensor<1x1x64xf32>) -> tensor<1x1x128xf32> - return %0 : tensor<1x1x128xf32> -} - -func.func @fc_relu(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>, - %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>) - -> tensor<512x512xf32> { - // Matrix-matrix multiplication. - %matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>) - outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32> - - // Elementwise addition. 
- %biased = linalg.elemwise_binary { fun = #linalg.binary_fn } - ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>) - outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> - - // Elementwise max with 0 (ReLU). - %c0f = arith.constant 0.0 : f32 - // expected-remark @below {{elementwise binary}} - %relued = linalg.elemwise_binary { fun = #linalg.binary_fn } - ins(%biased, %c0f : tensor<512x512xf32>, f32) - outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> - func.return %relued : tensor<512x512xf32> -} diff --git a/test/gc/transforms/cpu-vetor-distribution.mlir b/test/mlir/test/gc/Transforms/cpu-vetor-distribution.mlir similarity index 75% rename from test/gc/transforms/cpu-vetor-distribution.mlir rename to test/mlir/test/gc/Transforms/cpu-vetor-distribution.mlir index b6310946a..7a2eac7a8 100644 --- a/test/gc/transforms/cpu-vetor-distribution.mlir +++ b/test/mlir/test/gc/Transforms/cpu-vetor-distribution.mlir @@ -1,28 +1,51 @@ // RUN: gc-opt %s --split-input-file --fold-tensor-operation --lower-to-tile-vector --CPU-physical-register-pass --mlir-print-ir-after-all | FileCheck %s // CHECK-LABEL: func @add_tensor_test0 +// CHECK: %[[C4096:.*]] = arith.constant 4096 : index +// CHECK: %[[C16:.*]] = arith.constant 16 : index +// CHECK: %[[C11008:.*]] = arith.constant 11008 : index +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[TENSOR0:.*]] = tensor.empty() : tensor<11008x4096xf32> +// CHECK: scf.for +// CHECK: scf.for +// CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}}: tensor<11008x4096xf32>, vector<16xf32> +// CHECK: %[[READ1:.*]] = vector.transfer_read {{.*}}, {{.*}}: tensor<11008x4096xf32>, vector<16xf32> +// CHECK: %[[ADD0:.*]] = arith.addf %[[READ1]], %[[READ0]] : vector<16xf32> +// CHECK: %[[ADD1:.*]] = arith.addf %[[ADD0]], %[[READ0]] : vector<16xf32> +// CHECK: %[[WRITE:.*]] = vector.transfer_write {{.*}}, {{.*}} : vector<16xf32>, tensor<11008x4096xf32> func.func @add_tensor_test0(%arg0: tensor<11008x4096xf32>, %arg1: tensor<11008x4096xf32>) -> tensor<11008x4096xf32> { - // CHECK: %[[C4096:.*]] = arith.constant 4096 : index - // CHECK: %[[C16:.*]] = arith.constant 16 : index - // CHECK: %[[C11008:.*]] = arith.constant 11008 : index - // CHECK: %[[C1:.*]] = arith.constant 1 : index - // CHECK: %[[C0:.*]] = arith.constant 0 : index - // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 - // CHECK: %[[TENSOR0:.*]] = tensor.empty() : tensor<11008x4096xf32> - // CHECK: scf.for - // CHECK: scf.for - // CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}}: tensor<11008x4096xf32>, vector<16xf32> - // CHECK: %[[READ1:.*]] = vector.transfer_read {{.*}}, {{.*}}: tensor<11008x4096xf32>, vector<16xf32> - // CHECK: %[[ADD0:.*]] = arith.addf %[[READ0]], %[[READ1]] : vector<16xf32> - // CHECK: %[[ADD1:.*]] = arith.addf %[[ADD0]], %[[READ1]] : vector<16xf32> - // CHECK: %[[WRITE:.*]] = vector.transfer_write {{.*}}, {{.*}} : vector<16xf32>, tensor<11008x4096xf32> + %0 = tensor.empty() : tensor<11008x4096xf32> %1 = linalg.add ins(%arg0, %arg1 : tensor<11008x4096xf32>, tensor<11008x4096xf32>) outs(%0: tensor<11008x4096xf32>) -> tensor<11008x4096xf32> %2 = linalg.add ins(%1, %arg1 : tensor<11008x4096xf32>, tensor<11008x4096xf32>) outs(%0: tensor<11008x4096xf32>) -> tensor<11008x4096xf32> return %2 : tensor<11008x4096xf32> } -func.func @reduce_keepdim0(%arg0: tensor<16x32x64xf32>) -> tensor<16x1x64xf32> { +// CHECK-LABEL: func 
@reduce_keepdimtest1 +// CHECK: %[[C32:.*]] = arith.constant 32 : index +// CHECK: %[[C64:.*]] = arith.constant 64 : index +// CHECK: %[[C16:.*]] = arith.constant 16 : index +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[EMPTY0:.*]] = tensor.empty() : tensor<16x64xf32> +// CHECK: %[[EMPTY1:.*]] = tensor.empty() : tensor<16x1x64xf32> +// CHECK: scf.for +// CHECK: scf.for +// CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<16x64xf32>, vector<16xf32> +// CHECK: scf.for +// CHECK: %[[READ1:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<16x32x64xf32>, vector<16xf32> +// CHECK: %[[ADD0:.*]] = arith.addf %[[READ1]], {{.*}} : vector<16xf32> +// CHECK: scf.yield +// CHECK: %[[WRITE:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<16xf32>, tensor<16x64xf32> +// CHECK: scf.yield +// CHECK: scf.for +// CHECK: scf.for +// CHECK: scf.for +// CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<16x64xf32>, vector<16xf32> +func.func @reduce_keepdimtest1(%arg0: tensor<16x32x64xf32>) -> tensor<16x1x64xf32> { %0 = tensor.empty() : tensor<16x64xf32> %reduce = linalg.reduce ins(%arg0:tensor<16x32x64xf32>) @@ -36,18 +59,8 @@ func.func @reduce_keepdim0(%arg0: tensor<16x32x64xf32>) -> tensor<16x1x64xf32> { return %2 : tensor<16x1x64xf32> } -// // CHECK-LABEL: func @fc_relu -func.func @fc_relu(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>, - %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>) - -> tensor<512x512xf32> { - // CHECK: scf.for - // CHECK: scf.for - // CHECK: scf.for - // CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}}: tensor<512x512x512xf32>, vector<16xf32> - // CHECK: %[[READ1:.*]] = vector.transfer_read {{.*}}, {{.*}}: tensor<512x512x512xf32>, vector<16xf32> - // CHECK: %[[MULF0:.*]] = arith.mulf %[[READ0]], %[[READ1]] : vector<16xf32> - // CHECK: %[[WRITE:.*]] = vector.transfer_write {{.*}}, {{.*}} : vector<16xf32>, tensor<512x512x512xf32> - // CHECK-DAG: vector.multi_reduction +// CHECK-LABEL: func @fc_relu_test2 + // CHECK: %[[MATMUL:.*]] = linalg.matmul // CHECK: scf.for // CHECK: scf.for // CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}}: tensor<512x512xf32>, vector<16xf32> @@ -55,46 +68,37 @@ func.func @fc_relu(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>, // CHECK: %[[ADD0:.*]] = arith.addf %[[READ1]], %[[READ0]] : vector<16xf32> // CHECK: %[[ADD1:.*]] = arith.maximumf %[[ADD0]], {{.*}} : vector<16xf32> // CHECK: %[[WRITE:.*]] = vector.transfer_write {{.*}}, {{.*}} : vector<16xf32>, tensor<512x512xf32> +func.func @fc_relu_test2(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>, + %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>) + -> tensor<512x512xf32> { + // Matrix-matrix multiplication. %matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>) outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32> - // Elementwise addition. %biased = linalg.elemwise_binary { fun = #linalg.binary_fn } ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>) outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> - // Elementwise max with 0 (ReLU). 
%c0f = arith.constant 0.0 : f32 - // expected-remark @below {{elementwise binary}} %relued = linalg.elemwise_binary { fun = #linalg.binary_fn } ins(%biased, %c0f : tensor<512x512xf32>, f32) outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> func.return %relued : tensor<512x512xf32> } -// CHECK-LABEL: func @matmul_add -func.func @matmul_add(%arg0: tensor<8192x12288xf16>, %arg1: tensor<12288x16384xf16>, %arg2: tensor<8192x16384xf32>, %arg3: tensor<8192x16384xf32>) -> tensor<8192x16384xf32> { - // CHECK: vector.broadcast - // CHECK: vector.transpose - // CHECK: vector.broadcast - // CHECK: vector.transpose - // CHECK: scf.for - // CHECK: scf.for - // CHECK: scf.for - // CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}}: tensor<32x32x12288xf16>, vector<16xf16> - // CHECK: %[[EXTF0:.*]] = arith.extf %[[READ0]] : vector<16xf16> to vector<16xf32> - // CHECK: %[[READ1:.*]] = vector.transfer_read {{.*}}, {{.*}} : tensor<32x32x12288xf16>, vector<16xf16> - // CHECK: %[[EXTF1:.*]] = arith.extf %[[READ1]] : vector<16xf16> to vector<16xf32> - // CHECK: %[[MULF0:.*]] = arith.mulf %[[EXTF0]], %[[EXTF1]] : vector<16xf32> - // CHECK: %[[WRITE0:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<16xf32>, tensor<32x32x12288xf32> - // CHECK-DAG: vector.multi_reduction - // CHECK: scf.for - // CHECK: scf.for - // CHECK: %[[READ2:.*]] = vector.transfer_read {{.*}}, {{.*}}: tensor<8192x16384xf32>, vector<16xf32> - // CHECK: %[[READ3:.*]] = vector.transfer_read {{.*}}, {{.*}}: tensor<32x32xf32>, vector<16xf32> - // CHECK: %[[ADD0:.*]] = arith.addf %[[READ3]], %[[READ2]] : vector<16xf32> - // CHECK: %[[WRITE1:.*]] = vector.transfer_write {{.*}}, {{.*}} : vector<16xf32>, tensor<128x128xf32> +// CHECK-LABEL: func @matmul_add_test3 +// CHECK: %[[MATMUL0:.*]] = linalg.matmul +// CHECK: scf.for +// CHECK: scf.for +// CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<32x32xf32>, vector<16xf32> +// CHECK: %[[READ1:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<32x32xf32>, vector<16xf32> +// CHECK: %[[ADD0:.*]] = arith.addf %[[READ1]], %[[READ0]] : vector<16xf32> +// CHECK: %[[WRITE1:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<16xf32>, tensor<32x32xf32> +// CHECK: scf.yield +// CHECK: scf.yield +func.func @matmul_add_test3(%arg0: tensor<8192x12288xf16>, %arg1: tensor<12288x16384xf16>, %arg2: tensor<8192x16384xf32>, %arg3: tensor<8192x16384xf32>) -> tensor<8192x16384xf32> { + %0 = linalg.matmul {__fused_op__ = [0], name = "dot", tile_sizes = [[128, 128], [32, 32]]} ins(%arg0, %arg1 : tensor<8192x12288xf16>, tensor<12288x16384xf16>) outs(%arg2 : tensor<8192x16384xf32>) -> tensor<8192x16384xf32> %1 = tensor.empty() : tensor<8192x16384xf32> %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%0, %arg3 : tensor<8192x16384xf32>, tensor<8192x16384xf32>) outs(%1 : tensor<8192x16384xf32>) attrs = {__root_op__ = 0 : i64} { @@ -155,12 +159,13 @@ func.func @matmul_add(%arg0: tensor<8192x12288xf16>, %arg1: tensor<12288x16384xf return %3 : tensor<8192x16384xf32> } +// CHECK-LABEL: func @fuse_mlp_test4 #map = affine_map<(d0) -> (d0 * 64)> #map1 = affine_map<(d0) -> (d0 * 128)> #map2 = affine_map<(d0) -> (d0 * 4)> #map3 = affine_map<(d0) -> (d0 floordiv 16)> #map4 = affine_map<(d0) -> (d0 floordiv 32)> -func.func @fuse_mlp(%arg0: tensor<128x512xbf16>, %arg1: 
tensor<32x8x16x32xbf16>, %arg2: tensor<256xbf16>) -> tensor<128x256xbf16> { +func.func @fuse_mlp_test4(%arg0: tensor<128x512xbf16>, %arg1: tensor<32x8x16x32xbf16>, %arg2: tensor<256xbf16>) -> tensor<128x256xbf16> { %c32 = arith.constant 32 : index %c512 = arith.constant 512 : index %c128 = arith.constant 128 : index @@ -203,6 +208,37 @@ func.func @fuse_mlp(%arg0: tensor<128x512xbf16>, %arg1: tensor<32x8x16x32xbf16>, %extracted_slice_18 = tensor.extract_slice %extracted_slice_8[%14, %15, 0, 0] [32, 1, 16, 32] [1, 1, 1, 1] : tensor<32x4x16x32xbf16> to tensor<32x1x16x32xbf16> %extracted_slice_19 = tensor.extract_slice %extracted_slice_9[%arg28, %arg24] [512, 32] [1, 1] : tensor<512x128xbf16> to tensor<512x32xbf16> %unpack = tensor.unpack %extracted_slice_18 inner_dims_pos = [0, 1] inner_tiles = [16, 32] into %extracted_slice_19 : tensor<32x1x16x32xbf16> -> tensor<512x32xbf16> +// CHECK: %[[C16:.*]] = arith.constant 16 : index +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : bf16 +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[C64:.*]] = arith.constant 64 : index +// CHECK: %[[C128:.*]] = arith.constant 128 : index +// CHECK: %[[C512:.*]] = arith.constant 512 : index +// CHECK: %[[C32:.*]] = arith.constant 32 : index +// CHECK: %[[EMPTY0:.*]] = tensor.empty() : tensor<128x256xbf16> +// CHECK: scf.forall +// CHECK-COUNT-6: scf.for +// CHECK-COUNT-4: scf.for +// CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<32x1x16x32xbf16>, vector<1xbf16> +// CHECK: %[[WRITE0:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<1xbf16>, tensor<32x16x1x32xbf16> +// CHECK: %[[FILL0:.*]] = linalg.fill +// CHECK-COUNT-3: scf.for +// CHECK: %[[APPLY0:.*]] = affine.apply +// CHECK: %[[READ1:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<32x512xbf16>, vector<32xbf16> +// CHECK: %[[WRITE1:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<32xbf16>, tensor<1x32x512xbf16> +// CHECK-COUNT-4: scf.for +// CHECK: %[[READ2:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<32x16x1x32xbf16>, vector<32xbf16> +// CHECK: %[[APPLY1:.*]] = affine.apply +// CHECK: %[[WRITE2:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<32xbf16>, tensor<1x512x32xbf16> +// CHECK: %[[MATMUL0:.*]] = linalg.batch_reduce_matmul +// CHECK-COUNT-2: scf.for +// CHECK: %[[READ3:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<32x32xbf16>, vector<32xbf16> +// CHECK: %[[READ4:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<32xbf16>, vector<32xbf16> +// CHECK: %[[ADD0:.*]] = arith.addf %[[READ3]], %[[READ4]] : vector<32xbf16> +// CHECK: %[[EXP0:.*]] = math.exp %[[ADD0]] : vector<32xbf16> +// CHECK: %[[WRITE3:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<32xbf16>, tensor<32x32xbf16> +// CHECK: %[[WRITE4:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<32xbf16>, tensor<32x32xbf16> %extracted_slice_20 = tensor.extract_slice %arg29[%arg20, %arg24] [32, 32] [1, 1] : tensor<64x128xbf16> to tensor<32x32xbf16> %16 = linalg.fill ins(%cst : bf16) outs(%extracted_slice_20 : tensor<32x32xbf16>) -> tensor<32x32xbf16> %expanded = tensor.expand_shape %extracted_slice_17 [[0, 1], [2]] output_shape [1, 32, 512] : tensor<32x512xbf16> into tensor<1x32x512xbf16> From 7ac81c7ff965a883f7c791749e4f4931adb526f1 Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" 
Date: Mon, 2 Sep 2024 09:23:11 +0800 Subject: [PATCH 34/66] simplify code --- lib/gc/Transforms/CPUPhysicalRegisterPass.cpp | 106 +-- lib/gc/Transforms/TilingVector.h | 27 +- .../gc/Transforms/cpu-vetor-distribution.mlir | 659 +++++++++++------- 3 files changed, 464 insertions(+), 328 deletions(-) diff --git a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp index 38cc60f95..25dc2a436 100644 --- a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp +++ b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp @@ -220,14 +220,14 @@ int TypeHelper::generateValidSteps(int steps, VectorType type) { const int favx2bits = 256; if (HWInfo.favx512f) { return favx512bits / typebits; - } else if (HWInfo.favx2) { + } + if (HWInfo.favx2) { return favx2bits / typebits; - } else { - // invalid hardware - LDBG("Please check the hardware information."); - assert(false && "Invalid hardware."); - return -1; } + // invalid hardware + LDBG("Please check the hardware information."); + assert(false && "Invalid hardware."); + return -1; } /// Get a appropriate for loop step for current vector type @@ -1742,8 +1742,8 @@ scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( scf::ForOp ForLoopGenerator::generateTransposeForLoopWithLastDim( OpBuilder &opBuilder, const size_t grpIdx, const size_t forDimIdx, const int tpSteps, const Location &loc, SmallVector &inductionVars, - const ValueRange &iterArgs, DenseMap &operandIdxMap, - DenseMap &originalOperandMap) { + ValueRange iterArgs, DenseMap &operandIdxMap, + DenseMap &originalOperandMap, Operation *successorWriteOp) { auto &tpCanonicalizer = getTransposeCanonicalizers()[grpIdx]; vector::TransposeOp &tpOp = tpCanonicalizer.getCandidateOps()[0]; VectorType vtType = tpOp.getVector().getType(); @@ -1770,12 +1770,6 @@ scf::ForOp ForLoopGenerator::generateTransposeForLoopWithLastDim( Value source = tpOp->getOperand(0); auto readSourceOp = cast(source.getDefiningOp()); - vector::TransferWriteOp successorWriteOp; - for (Operation *x : tpOp->getUsers()) { - if (isa(x)) { - successorWriteOp = cast(x); - } - } auto padValue = b.create( loc, b.getZeroAttr(vtType.getElementType())); SmallVector inBoundsVal(2, true); @@ -1808,7 +1802,7 @@ scf::ForOp ForLoopGenerator::generateTransposeForLoopWithLastDim( // outter loop auto nxtFor = generateTransposeForLoopWithLastDim( b, grpIdx, forDimIdx + 1, tpSteps, loc, inductionVars, loopState, - operandIdxMap, originalOperandMap); + operandIdxMap, originalOperandMap, successorWriteOp); maybeYieldValue(b, loc, nxtFor->getResults()); } }); @@ -2014,7 +2008,7 @@ scf::ForOp ForLoopGenerator::generateTransposeScalarDataMovement( }); } -// generate simple data movement for loop +/// generate simple data movement for loop scf::ForOp ForLoopGenerator::generateShapeCastReadWriteLoop( OpBuilder &opBuilder, const size_t grpIdx, const size_t forDimIdx, const size_t steps, const Location &loc, SmallVector &inductionVars, @@ -2173,19 +2167,19 @@ void ForLoopGenerator::rectifyWriteOperationIndice( void ForLoopGenerator::rectifyReadOperationIndice( vector::TransferReadOp *originalReadOp, VectorType loopType, ArrayRef inductionVars, SmallVectorImpl &readVars) { - VectorType originalReadVectorType = originalReadOp->getVectorType(); + ShapedType readTensorType = + cast(originalReadOp->getSource().getType()); // currently only broadcast (fuse as transfer_read) will move into more inner // loop - // TODO: Need to better process the broadcast operation - if (originalReadVectorType.getRank() - 1 < + if (readTensorType.getRank() - 
1 >= getFusionStrategy().getOpAnchorPos()[*originalReadOp]) { return; } + int64_t itrIdx = loopType.getRank() - 1; - int64_t readIdx = originalReadVectorType.getRank() - 1; + int64_t readIdx = readTensorType.getRank() - 1; while (itrIdx >= 0 and readIdx >= 0) { - if (originalReadVectorType.getShape()[readIdx] == - loopType.getShape()[itrIdx]) { + if (readTensorType.getShape()[readIdx] == loopType.getShape()[itrIdx]) { readVars[readIdx] = inductionVars[itrIdx]; readIdx--; } @@ -2272,7 +2266,6 @@ void ForLoopGenerator::clearCurrentOperationGroup(size_t grpIdx) { std::queue().swap(getFusionStrategy().getOpGroups()[grpIdx]); }; -/// generate transpose for loop scf::ForOp ForLoopGenerator::generateTransposeForLoop(const size_t grpIdx) { // transpose rank must bigger than 2 @@ -2294,28 +2287,18 @@ scf::ForOp ForLoopGenerator::generateTransposeForLoop(const size_t grpIdx) { DenseSet permuteSet(permutation.begin(), permutation.end()); bool isTwoDTranspose = tpCanonicalizer.isTwoDTranspose(); - vector::TransferWriteOp successorWriteOp; - for (Operation *x : tpOp->getUsers()) { - if (isa(x)) { - successorWriteOp = cast(x); - break; - } - } - // iterArgs.emplace_back(successorWriteOp->getOperands()[1]); - SmallVector operands; + Operation *successorWriteOp = + getNextTargetOperationInCurrentGroup(tpOp, + grpIdx); + DenseMap operandIdxMap; DenseMap originalOperandMap, operandOriginalMap, resultIdxMap, forResultOrignalResultMap; - SetVector &initArgs = getGroupOpInitArgs()[grpIdx]; - for (Value x : initArgs) { - operands.emplace_back(x); - operandIdxMap[x] = operands.size() - 1; - originalOperandMap[x] = x; - operandOriginalMap[x] = x; - } - SmallVector iterArgs(operands.begin(), operands.end()); + SmallVector iterArgs = prepareForLoopArgs( + grpIdx, operandIdxMap, originalOperandMap, operandOriginalMap); SmallVector inductionVars; + // TODO: need to process transpose on all one dim // don't need to do the transpose // if (tpCanonicalizer.isTransposeOnAllOneDim()) { // removeOpInCurrentGroups(grpIdx, tpOp, @@ -2366,7 +2349,7 @@ scf::ForOp ForLoopGenerator::generateTransposeForLoop(const size_t grpIdx) { isTwoDTranspose) { scf::ForOp forOp = generateTransposeForLoopWithLastDim( b, grpIdx, 0, tpStep, tpOp.getLoc(), inductionVars, iterArgs, - operandIdxMap, originalOperandMap); + operandIdxMap, originalOperandMap, successorWriteOp); rewriter.replaceOp(successorWriteOp, forOp); // clear current group operation @@ -3280,11 +3263,6 @@ bool VectorFusionStrategy::isNeedNewGroup(Operation *op) { return true; } - // read and write operation dependency - // if (readWriteDependency(prevOp, op)) { - // return true; - // } - // special operation need to check data dependency axis if (hasDataDependency(prevOp, op)) { return true; @@ -3622,11 +3600,15 @@ void CanonicalizerCommonUsedData::updateOpOperandResultInGroups( getFusionStrategy().getOpAnchorPos()[op]; } // directly use the read operation to do the fusion - if (isa(op) and !tmpOpQueue.empty()) { - IRRewriter rewrite(op); - rewrite.replaceOp(op, op->getOperand(0).getDefiningOp()); - continue; - } + // if (isa(op) and not disableBroadcastOp) { + // auto srcOp = op->getOperand(0).getDefiningOp(); + // getFusionStrategy().getOpAnchorPos()[srcOp] = + // getFusionStrategy().getOpAnchorPos()[op]; + + // IRRewriter rewrite(op); + // rewrite.replaceOp(op, srcOp); + // continue; + // } newOpQueue.push(op); if (result && !failed(getOperationVectorType(result.getDefiningOp()))) { @@ -3656,7 +3638,6 @@ void CanonicalizerCommonUsedData::generateEmptyTensorAndWrite( auto 
[tsr, writeOpresult] = canonicalizeSourceOperation(sourceOp, visitedOperation); auto writeOp = writeOpresult.getDefiningOp(); - assert(writeOp); srcOpCanoniclizedMap.insert({sourceOp, {tsr, writeOpresult}}); updateOpOperandResultInGroups(sourceOpGid, sourceOp, tsr, writeOpresult); groupOpInitArgs[sourceOpGid].insert(tsr); @@ -3743,17 +3724,15 @@ void VectorOperationAnalyzer::specialOperationRectify( auto op = tmpQueue.front(); tmpQueue.pop(); // remain transfer read operation to do the broadcast fusion - if (isa(op) and !disableBroadcastOp) { + if (isa(op) and not disableBroadcastOp) { auto srcOp = op->getOperand(0).getDefiningOp(); assert(isa(srcOp)); // only have write operation, otherwise the group size will bigger // than 1. Because the last operation is always a write operation in // each group - if (tmpQueue.size() <= 1) - continue; - getFusionStrategy().getOpAnchorPos()[srcOp] = getFusionStrategy().getOpAnchorPos()[op]; + rewriter.replaceOp(op, srcOp); continue; } @@ -3763,21 +3742,6 @@ void VectorOperationAnalyzer::specialOperationRectify( getFusionStrategy().getOpAnchorPos()[accSourceOp] = getOperationVectorType(accSourceOp)->getRank() - 1; } - // // case: - // // %1 = some op - // // %2 = tensor.empty() - // // %3 = vector.transfer_write %1, %2 - // // -> move emtpy operation before %1 for better generate %1 - // if (isa(op)) { - // auto srcOp = op->getOperand(1).getDefiningOp(); - // if (isa_and_nonnull(srcOp)) { - // Operation *writeVectorOp = op->getOperands()[0].getDefiningOp(); - // if (visitedOperation[srcOp] >= visitedOperation[writeVectorOp]) { - // srcOp->moveBefore(writeVectorOp); - // visitedOperation[srcOp] = visitedOperation[writeVectorOp]; - // } - // } - // } newQueue.push(op); } getFusionStrategy().getOpGroups()[idx] = newQueue; diff --git a/lib/gc/Transforms/TilingVector.h b/lib/gc/Transforms/TilingVector.h index e1bcaf292..1eee9b112 100644 --- a/lib/gc/Transforms/TilingVector.h +++ b/lib/gc/Transforms/TilingVector.h @@ -53,13 +53,19 @@ namespace mlir { namespace gc { namespace { +//===----------------------------------------------------------------------===// +// helper function +//===----------------------------------------------------------------------===// + +/// build a constant operation of index type Value makeIndexArithConstantOp(OpBuilder &opBuilder, Location &loc, int64_t x); +/// set correct operand for the operation void setOperationCorrectOperand( - Operation *op, const ValueRange &iterArgs, - const llvm::DenseMap &operandIdxMap, + Operation *op, ValueRange iterArgs, DenseMap &operandIdxMap, DenseMap &originalOperandLoopArgsMap, ArrayRef inductionVars, - const llvm::DenseMap &opPermuationMap); + DenseMap &opPermuationMap); +/// get operation read or write tensor mlir::FailureOr getOperationOperateTensor(Operation *op); struct HardWareInfo { @@ -67,16 +73,21 @@ struct HardWareInfo { bool favx2 = true; }; -/// VectorType conversion helper class +/// Vector type conversion helper class class TypeHelper { private: HardWareInfo HWInfo; public: + /// use \param info to set hardware information void setHardWareInfo(HardWareInfo &info) { HWInfo = info; } + /// get vector \param type max loop step according to hardware information int getDataTypeValidSteps(VectorType type); + /// get vector \param type an even for loop step int generateValidSteps(int steps, VectorType type); + /// get vector \param type max simd length according to hardware information int getDataTypeMAXSIMDLength(VectorType type); + /// get operation's vector type VectorType 
getVectorzedType(Operation *op, uint32_t loopStep = 0); }; @@ -361,7 +372,7 @@ class CanonicalizerCommonUsedData : public TypeHelper { for (size_t i = 0; i < opGroups.size(); i++) { groupOpResults.emplace_back( llvm::MapVector>()); - groupOpInitArgs.emplace_back(llvm::SetVector()); + groupOpInitArgs.emplace_back(SetVector()); } } } @@ -614,13 +625,13 @@ class ForLoopGenerator : virtual public CanonicalizerCommonUsedData { llvm::SmallVector &inductionVars, bool lastDimReduction, MultiReduceOpAxisKind rdKind = MultiReduceOpAxisKind::Parallel); - /// transpose operation related + /// generate for loop for transpose operation scf::ForOp generateTransposeForLoop(const size_t groupId); scf::ForOp generateTransposeForLoopWithLastDim( OpBuilder &opBuilder, const size_t grpIdx, const size_t forDimIdx, const int tpSteps, const Location &loc, SmallVector &inductionVars, - const ValueRange &iterArgs, DenseMap &operandIdxMap, - DenseMap &originalOperandMap); + ValueRange iterArgs, DenseMap &operandIdxMap, + DenseMap &originalOperandMap, Operation *successorWriteOp); scf::ForOp generateTransposeScalarDataMovement( OpBuilder &opBuilder, const size_t grpIdx, const size_t forDimIdx, diff --git a/test/mlir/test/gc/Transforms/cpu-vetor-distribution.mlir b/test/mlir/test/gc/Transforms/cpu-vetor-distribution.mlir index 7a2eac7a8..2b96d2fec 100644 --- a/test/mlir/test/gc/Transforms/cpu-vetor-distribution.mlir +++ b/test/mlir/test/gc/Transforms/cpu-vetor-distribution.mlir @@ -1,5 +1,18 @@ // RUN: gc-opt %s --split-input-file --fold-tensor-operation --lower-to-tile-vector --CPU-physical-register-pass --mlir-print-ir-after-all | FileCheck %s + +// CHECK-DAG: #[[map0:.*]] = affine_map<()[s0, s1] -> (s0 * 64 + s1)> +// CHECK-DAG: #[[map1:.*]] = affine_map<()[s0, s1] -> (s0 + s1)> +// CHECK-DAG: #[[map2:.*]] = affine_map<(d0) -> (d0 * 4)> +// CHECK-DAG: #[[map3:.*]] = affine_map<(d0) -> (d0 * 128)> +// CHECK-DAG: #[[map4:.*]] = affine_map<(d0) -> (d0 * 64)> +// CHECK-DAG: #[[map5:.*]] = affine_map<(d0, d1) -> (d0 floordiv 32 + d1 floordiv 32)> +// CHECK-DAG: #[[map6:.*]] = affine_map<(d0, d1) -> (d0 floordiv 16 + d1 floordiv 16)> +// CHECK-DAG: #[[map7:.*]] = affine_map<()[s0, s1] -> (s0 * 32 + s1)> +// CHECK-DAG: #[[map8:.*]] = affine_map<()[s0, s1] -> (s0 * 16 + s1)> + + + // CHECK-LABEL: func @add_tensor_test0 // CHECK: %[[C4096:.*]] = arith.constant 4096 : index // CHECK: %[[C16:.*]] = arith.constant 16 : index @@ -15,265 +28,413 @@ // CHECK: %[[ADD0:.*]] = arith.addf %[[READ1]], %[[READ0]] : vector<16xf32> // CHECK: %[[ADD1:.*]] = arith.addf %[[ADD0]], %[[READ0]] : vector<16xf32> // CHECK: %[[WRITE:.*]] = vector.transfer_write {{.*}}, {{.*}} : vector<16xf32>, tensor<11008x4096xf32> -func.func @add_tensor_test0(%arg0: tensor<11008x4096xf32>, %arg1: tensor<11008x4096xf32>) -> tensor<11008x4096xf32> { +// func.func @add_tensor_test0(%arg0: tensor<11008x4096xf32>, %arg1: tensor<11008x4096xf32>) -> tensor<11008x4096xf32> { - %0 = tensor.empty() : tensor<11008x4096xf32> - %1 = linalg.add ins(%arg0, %arg1 : tensor<11008x4096xf32>, tensor<11008x4096xf32>) outs(%0: tensor<11008x4096xf32>) -> tensor<11008x4096xf32> - %2 = linalg.add ins(%1, %arg1 : tensor<11008x4096xf32>, tensor<11008x4096xf32>) outs(%0: tensor<11008x4096xf32>) -> tensor<11008x4096xf32> - return %2 : tensor<11008x4096xf32> -} +// %0 = tensor.empty() : tensor<11008x4096xf32> +// %1 = linalg.add ins(%arg0, %arg1 : tensor<11008x4096xf32>, tensor<11008x4096xf32>) outs(%0: tensor<11008x4096xf32>) -> tensor<11008x4096xf32> +// %2 = linalg.add ins(%1, %arg1 : 
tensor<11008x4096xf32>, tensor<11008x4096xf32>) outs(%0: tensor<11008x4096xf32>) -> tensor<11008x4096xf32> +// return %2 : tensor<11008x4096xf32> +// } -// CHECK-LABEL: func @reduce_keepdimtest1 -// CHECK: %[[C32:.*]] = arith.constant 32 : index -// CHECK: %[[C64:.*]] = arith.constant 64 : index -// CHECK: %[[C16:.*]] = arith.constant 16 : index -// CHECK: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[C0:.*]] = arith.constant 0 : index -// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[EMPTY0:.*]] = tensor.empty() : tensor<16x64xf32> -// CHECK: %[[EMPTY1:.*]] = tensor.empty() : tensor<16x1x64xf32> -// CHECK: scf.for -// CHECK: scf.for -// CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<16x64xf32>, vector<16xf32> -// CHECK: scf.for -// CHECK: %[[READ1:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<16x32x64xf32>, vector<16xf32> -// CHECK: %[[ADD0:.*]] = arith.addf %[[READ1]], {{.*}} : vector<16xf32> -// CHECK: scf.yield -// CHECK: %[[WRITE:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<16xf32>, tensor<16x64xf32> -// CHECK: scf.yield -// CHECK: scf.for -// CHECK: scf.for -// CHECK: scf.for -// CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<16x64xf32>, vector<16xf32> -func.func @reduce_keepdimtest1(%arg0: tensor<16x32x64xf32>) -> tensor<16x1x64xf32> { - %0 = tensor.empty() : tensor<16x64xf32> - %reduce = linalg.reduce - ins(%arg0:tensor<16x32x64xf32>) - outs(%0:tensor<16x64xf32>) - dimensions = [1] - (%in: f32, %out: f32) { - %1 = arith.addf %out, %in: f32 - linalg.yield %1: f32 - } - %2 = tensor.expand_shape %reduce [[0],[1, 2]] output_shape [16, 1, 64] : tensor<16x64xf32> into tensor<16x1x64xf32> - return %2 : tensor<16x1x64xf32> -} +// // CHECK-LABEL: func @reduce_keepdimtest1 +// // CHECK: %[[C32:.*]] = arith.constant 32 : index +// // CHECK: %[[C64:.*]] = arith.constant 64 : index +// // CHECK: %[[C16:.*]] = arith.constant 16 : index +// // CHECK: %[[C1:.*]] = arith.constant 1 : index +// // CHECK: %[[C0:.*]] = arith.constant 0 : index +// // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// // CHECK: %[[EMPTY0:.*]] = tensor.empty() : tensor<16x64xf32> +// // CHECK: %[[EMPTY1:.*]] = tensor.empty() : tensor<16x1x64xf32> +// // CHECK: scf.for +// // CHECK: scf.for +// // CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<16x64xf32>, vector<16xf32> +// // CHECK: scf.for +// // CHECK: %[[READ1:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<16x32x64xf32>, vector<16xf32> +// // CHECK: %[[ADD0:.*]] = arith.addf %[[READ1]], {{.*}} : vector<16xf32> +// // CHECK: scf.yield +// // CHECK: %[[WRITE:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<16xf32>, tensor<16x64xf32> +// // CHECK: scf.yield +// // CHECK: scf.for +// // CHECK: scf.for +// // CHECK: scf.for +// // CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<16x64xf32>, vector<16xf32> +// func.func @reduce_keepdimtest1(%arg0: tensor<16x32x64xf32>) -> tensor<16x1x64xf32> { +// %0 = tensor.empty() : tensor<16x64xf32> +// %reduce = linalg.reduce +// ins(%arg0:tensor<16x32x64xf32>) +// outs(%0:tensor<16x64xf32>) +// dimensions = [1] +// (%in: f32, %out: f32) { +// %1 = arith.addf %out, %in: f32 +// linalg.yield %1: f32 +// } +// %2 = tensor.expand_shape %reduce [[0],[1, 2]] output_shape [16, 1, 64] : tensor<16x64xf32> into tensor<16x1x64xf32> +// 
return %2 : tensor<16x1x64xf32> +// } -// CHECK-LABEL: func @fc_relu_test2 - // CHECK: %[[MATMUL:.*]] = linalg.matmul - // CHECK: scf.for - // CHECK: scf.for - // CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}}: tensor<512x512xf32>, vector<16xf32> - // CHECK: %[[READ1:.*]] = vector.transfer_read {{.*}}, {{.*}}: tensor<512x512xf32>, vector<16xf32> - // CHECK: %[[ADD0:.*]] = arith.addf %[[READ1]], %[[READ0]] : vector<16xf32> - // CHECK: %[[ADD1:.*]] = arith.maximumf %[[ADD0]], {{.*}} : vector<16xf32> - // CHECK: %[[WRITE:.*]] = vector.transfer_write {{.*}}, {{.*}} : vector<16xf32>, tensor<512x512xf32> -func.func @fc_relu_test2(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>, - %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>) - -> tensor<512x512xf32> { +// // CHECK-LABEL: func @fc_relu_test2 +// // CHECK: %[[MATMUL:.*]] = linalg.matmul +// // CHECK: scf.for +// // CHECK: scf.for +// // CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}}: tensor<512x512xf32>, vector<16xf32> +// // CHECK: %[[READ1:.*]] = vector.transfer_read {{.*}}, {{.*}}: tensor<512x512xf32>, vector<16xf32> +// // CHECK: %[[ADD0:.*]] = arith.addf %[[READ1]], %[[READ0]] : vector<16xf32> +// // CHECK: %[[ADD1:.*]] = arith.maximumf %[[ADD0]], {{.*}} : vector<16xf32> +// // CHECK: %[[WRITE:.*]] = vector.transfer_write {{.*}}, {{.*}} : vector<16xf32>, tensor<512x512xf32> +// func.func @fc_relu_test2(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>, +// %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>) +// -> tensor<512x512xf32> { - // Matrix-matrix multiplication. - %matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>) - outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32> +// // Matrix-matrix multiplication. +// %matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>) +// outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32> - %biased = linalg.elemwise_binary { fun = #linalg.binary_fn } - ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>) - outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> +// %biased = linalg.elemwise_binary { fun = #linalg.binary_fn } +// ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>) +// outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> - %c0f = arith.constant 0.0 : f32 - %relued = linalg.elemwise_binary { fun = #linalg.binary_fn } - ins(%biased, %c0f : tensor<512x512xf32>, f32) - outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> - func.return %relued : tensor<512x512xf32> -} +// %c0f = arith.constant 0.0 : f32 +// %relued = linalg.elemwise_binary { fun = #linalg.binary_fn } +// ins(%biased, %c0f : tensor<512x512xf32>, f32) +// outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> +// func.return %relued : tensor<512x512xf32> +// } -// CHECK-LABEL: func @matmul_add_test3 -// CHECK: %[[MATMUL0:.*]] = linalg.matmul -// CHECK: scf.for -// CHECK: scf.for -// CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<32x32xf32>, vector<16xf32> -// CHECK: %[[READ1:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<32x32xf32>, vector<16xf32> -// CHECK: %[[ADD0:.*]] = arith.addf %[[READ1]], %[[READ0]] : vector<16xf32> -// CHECK: %[[WRITE1:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<16xf32>, tensor<32x32xf32> -// CHECK: scf.yield -// CHECK: scf.yield -func.func @matmul_add_test3(%arg0: tensor<8192x12288xf16>, %arg1: tensor<12288x16384xf16>, %arg2: tensor<8192x16384xf32>, 
%arg3: tensor<8192x16384xf32>) -> tensor<8192x16384xf32> { +// // CHECK-LABEL: func @matmul_add_test3 +// // CHECK: %[[MATMUL0:.*]] = linalg.matmul +// // CHECK: scf.for +// // CHECK: scf.for +// // CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<32x32xf32>, vector<16xf32> +// // CHECK: %[[READ1:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<32x32xf32>, vector<16xf32> +// // CHECK: %[[ADD0:.*]] = arith.addf %[[READ1]], %[[READ0]] : vector<16xf32> +// // CHECK: %[[WRITE1:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<16xf32>, tensor<32x32xf32> +// // CHECK: scf.yield +// // CHECK: scf.yield +// func.func @matmul_add_test3(%arg0: tensor<8192x12288xf16>, %arg1: tensor<12288x16384xf16>, %arg2: tensor<8192x16384xf32>, %arg3: tensor<8192x16384xf32>) -> tensor<8192x16384xf32> { - %0 = linalg.matmul {__fused_op__ = [0], name = "dot", tile_sizes = [[128, 128], [32, 32]]} ins(%arg0, %arg1 : tensor<8192x12288xf16>, tensor<12288x16384xf16>) outs(%arg2 : tensor<8192x16384xf32>) -> tensor<8192x16384xf32> - %1 = tensor.empty() : tensor<8192x16384xf32> - %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%0, %arg3 : tensor<8192x16384xf32>, tensor<8192x16384xf32>) outs(%1 : tensor<8192x16384xf32>) attrs = {__root_op__ = 0 : i64} { - ^bb0(%in: f32, %in_0: f32, %out: f32): - %4 = arith.addf %in, %in_0 : f32 - linalg.yield %4 : f32 - } -> tensor<8192x16384xf32> - %c0 = arith.constant 0 : index - %c8192 = arith.constant 8192 : index - %c128 = arith.constant 128 : index - %3 = scf.for %arg4 = %c0 to %c8192 step %c128 iter_args(%arg5 = %1) -> (tensor<8192x16384xf32>) { - %c0_0 = arith.constant 0 : index - %c16384 = arith.constant 16384 : index - %c128_1 = arith.constant 128 : index - %4 = scf.for %arg6 = %c0_0 to %c16384 step %c128_1 iter_args(%arg7 = %arg5) -> (tensor<8192x16384xf32>) { - %extracted_slice = tensor.extract_slice %arg0[%arg4, 0] [128, 12288] [1, 1] : tensor<8192x12288xf16> to tensor<128x12288xf16> - %extracted_slice_2 = tensor.extract_slice %arg1[0, %arg6] [12288, 128] [1, 1] : tensor<12288x16384xf16> to tensor<12288x128xf16> - %extracted_slice_3 = tensor.extract_slice %arg2[%arg4, %arg6] [128, 128] [1, 1] : tensor<8192x16384xf32> to tensor<128x128xf32> - %5 = linalg.matmul {__fused_op__ = [0], name = "dot", tile_sizes = [[128, 128], [32, 32]]} ins(%extracted_slice, %extracted_slice_2 : tensor<128x12288xf16>, tensor<12288x128xf16>) outs(%extracted_slice_3 : tensor<128x128xf32>) -> tensor<128x128xf32> - %extracted_slice_4 = tensor.extract_slice %0[%arg4, %arg6] [128, 128] [1, 1] : tensor<8192x16384xf32> to tensor<128x128xf32> - %extracted_slice_5 = tensor.extract_slice %arg3[%arg4, %arg6] [128, 128] [1, 1] : tensor<8192x16384xf32> to tensor<128x128xf32> - %extracted_slice_6 = tensor.extract_slice %arg7[%arg4, %arg6] [128, 128] [1, 1] : tensor<8192x16384xf32> to tensor<128x128xf32> - %6 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%5, %extracted_slice_5 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%extracted_slice_6 : tensor<128x128xf32>) attrs = {__root_op__ = 0 : i64} { - ^bb0(%in: f32, %in_9: f32, %out: f32): - %8 = arith.addf %in, %in_9 : f32 - linalg.yield %8 : f32 - } -> tensor<128x128xf32> - %c0_7 = arith.constant 0 : index 
- %c128_8 = arith.constant 128 : index - %c32 = arith.constant 32 : index - %7 = scf.for %arg8 = %c0_7 to %c128_8 step %c32 iter_args(%arg9 = %extracted_slice_6) -> (tensor<128x128xf32>) { - %c0_9 = arith.constant 0 : index - %c128_10 = arith.constant 128 : index - %c32_11 = arith.constant 32 : index - %8 = scf.for %arg10 = %c0_9 to %c128_10 step %c32_11 iter_args(%arg11 = %arg9) -> (tensor<128x128xf32>) { - %extracted_slice_12 = tensor.extract_slice %extracted_slice[%arg8, 0] [32, 12288] [1, 1] : tensor<128x12288xf16> to tensor<32x12288xf16> - %extracted_slice_13 = tensor.extract_slice %extracted_slice_2[0, %arg10] [12288, 32] [1, 1] : tensor<12288x128xf16> to tensor<12288x32xf16> - %extracted_slice_14 = tensor.extract_slice %extracted_slice_3[%arg8, %arg10] [32, 32] [1, 1] : tensor<128x128xf32> to tensor<32x32xf32> - %9 = linalg.matmul {__fused_op__ = [0], name = "dot", tile_sizes = [[128, 128], [32, 32]]} ins(%extracted_slice_12, %extracted_slice_13 : tensor<32x12288xf16>, tensor<12288x32xf16>) outs(%extracted_slice_14 : tensor<32x32xf32>) -> tensor<32x32xf32> - %extracted_slice_15 = tensor.extract_slice %5[%arg8, %arg10] [32, 32] [1, 1] : tensor<128x128xf32> to tensor<32x32xf32> - %extracted_slice_16 = tensor.extract_slice %extracted_slice_5[%arg8, %arg10] [32, 32] [1, 1] : tensor<128x128xf32> to tensor<32x32xf32> - %extracted_slice_17 = tensor.extract_slice %arg11[%arg8, %arg10] [32, 32] [1, 1] : tensor<128x128xf32> to tensor<32x32xf32> - %10 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%9, %extracted_slice_16 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice_17 : tensor<32x32xf32>) attrs = {__root_op__ = 0 : i64} { - ^bb0(%in: f32, %in_19: f32, %out: f32): - %11 = arith.addf %in, %in_19 : f32 - linalg.yield %11 : f32 - } -> tensor<32x32xf32> - %inserted_slice_18 = tensor.insert_slice %10 into %arg11[%arg8, %arg10] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<128x128xf32> - scf.yield %inserted_slice_18 : tensor<128x128xf32> - } {__parallel_loop__ = 1 : i64} - scf.yield %8 : tensor<128x128xf32> - } {__parallel_loop__ = 1 : i64} - %inserted_slice = tensor.insert_slice %7 into %arg7[%arg4, %arg6] [128, 128] [1, 1] : tensor<128x128xf32> into tensor<8192x16384xf32> - scf.yield %inserted_slice : tensor<8192x16384xf32> - } {__parallel_loop__ = 0 : i64} - scf.yield %4 : tensor<8192x16384xf32> - } {__parallel_loop__ = 0 : i64} - return %3 : tensor<8192x16384xf32> -} +// %0 = linalg.matmul {__fused_op__ = [0], name = "dot", tile_sizes = [[128, 128], [32, 32]]} ins(%arg0, %arg1 : tensor<8192x12288xf16>, tensor<12288x16384xf16>) outs(%arg2 : tensor<8192x16384xf32>) -> tensor<8192x16384xf32> +// %1 = tensor.empty() : tensor<8192x16384xf32> +// %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%0, %arg3 : tensor<8192x16384xf32>, tensor<8192x16384xf32>) outs(%1 : tensor<8192x16384xf32>) attrs = {__root_op__ = 0 : i64} { +// ^bb0(%in: f32, %in_0: f32, %out: f32): +// %4 = arith.addf %in, %in_0 : f32 +// linalg.yield %4 : f32 +// } -> tensor<8192x16384xf32> +// %c0 = arith.constant 0 : index +// %c8192 = arith.constant 8192 : index +// %c128 = arith.constant 128 : index +// %3 = scf.for %arg4 = %c0 to %c8192 step %c128 iter_args(%arg5 = %1) -> (tensor<8192x16384xf32>) { +// %c0_0 = arith.constant 0 : index +// 
%c16384 = arith.constant 16384 : index +// %c128_1 = arith.constant 128 : index +// %4 = scf.for %arg6 = %c0_0 to %c16384 step %c128_1 iter_args(%arg7 = %arg5) -> (tensor<8192x16384xf32>) { +// %extracted_slice = tensor.extract_slice %arg0[%arg4, 0] [128, 12288] [1, 1] : tensor<8192x12288xf16> to tensor<128x12288xf16> +// %extracted_slice_2 = tensor.extract_slice %arg1[0, %arg6] [12288, 128] [1, 1] : tensor<12288x16384xf16> to tensor<12288x128xf16> +// %extracted_slice_3 = tensor.extract_slice %arg2[%arg4, %arg6] [128, 128] [1, 1] : tensor<8192x16384xf32> to tensor<128x128xf32> +// %5 = linalg.matmul {__fused_op__ = [0], name = "dot", tile_sizes = [[128, 128], [32, 32]]} ins(%extracted_slice, %extracted_slice_2 : tensor<128x12288xf16>, tensor<12288x128xf16>) outs(%extracted_slice_3 : tensor<128x128xf32>) -> tensor<128x128xf32> +// %extracted_slice_4 = tensor.extract_slice %0[%arg4, %arg6] [128, 128] [1, 1] : tensor<8192x16384xf32> to tensor<128x128xf32> +// %extracted_slice_5 = tensor.extract_slice %arg3[%arg4, %arg6] [128, 128] [1, 1] : tensor<8192x16384xf32> to tensor<128x128xf32> +// %extracted_slice_6 = tensor.extract_slice %arg7[%arg4, %arg6] [128, 128] [1, 1] : tensor<8192x16384xf32> to tensor<128x128xf32> +// %6 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%5, %extracted_slice_5 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%extracted_slice_6 : tensor<128x128xf32>) attrs = {__root_op__ = 0 : i64} { +// ^bb0(%in: f32, %in_9: f32, %out: f32): +// %8 = arith.addf %in, %in_9 : f32 +// linalg.yield %8 : f32 +// } -> tensor<128x128xf32> +// %c0_7 = arith.constant 0 : index +// %c128_8 = arith.constant 128 : index +// %c32 = arith.constant 32 : index +// %7 = scf.for %arg8 = %c0_7 to %c128_8 step %c32 iter_args(%arg9 = %extracted_slice_6) -> (tensor<128x128xf32>) { +// %c0_9 = arith.constant 0 : index +// %c128_10 = arith.constant 128 : index +// %c32_11 = arith.constant 32 : index +// %8 = scf.for %arg10 = %c0_9 to %c128_10 step %c32_11 iter_args(%arg11 = %arg9) -> (tensor<128x128xf32>) { +// %extracted_slice_12 = tensor.extract_slice %extracted_slice[%arg8, 0] [32, 12288] [1, 1] : tensor<128x12288xf16> to tensor<32x12288xf16> +// %extracted_slice_13 = tensor.extract_slice %extracted_slice_2[0, %arg10] [12288, 32] [1, 1] : tensor<12288x128xf16> to tensor<12288x32xf16> +// %extracted_slice_14 = tensor.extract_slice %extracted_slice_3[%arg8, %arg10] [32, 32] [1, 1] : tensor<128x128xf32> to tensor<32x32xf32> +// %9 = linalg.matmul {__fused_op__ = [0], name = "dot", tile_sizes = [[128, 128], [32, 32]]} ins(%extracted_slice_12, %extracted_slice_13 : tensor<32x12288xf16>, tensor<12288x32xf16>) outs(%extracted_slice_14 : tensor<32x32xf32>) -> tensor<32x32xf32> +// %extracted_slice_15 = tensor.extract_slice %5[%arg8, %arg10] [32, 32] [1, 1] : tensor<128x128xf32> to tensor<32x32xf32> +// %extracted_slice_16 = tensor.extract_slice %extracted_slice_5[%arg8, %arg10] [32, 32] [1, 1] : tensor<128x128xf32> to tensor<32x32xf32> +// %extracted_slice_17 = tensor.extract_slice %arg11[%arg8, %arg10] [32, 32] [1, 1] : tensor<128x128xf32> to tensor<32x32xf32> +// %10 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%9, %extracted_slice_16 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice_17 : tensor<32x32xf32>) attrs = 
{__root_op__ = 0 : i64} { +// ^bb0(%in: f32, %in_19: f32, %out: f32): +// %11 = arith.addf %in, %in_19 : f32 +// linalg.yield %11 : f32 +// } -> tensor<32x32xf32> +// %inserted_slice_18 = tensor.insert_slice %10 into %arg11[%arg8, %arg10] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<128x128xf32> +// scf.yield %inserted_slice_18 : tensor<128x128xf32> +// } {__parallel_loop__ = 1 : i64} +// scf.yield %8 : tensor<128x128xf32> +// } {__parallel_loop__ = 1 : i64} +// %inserted_slice = tensor.insert_slice %7 into %arg7[%arg4, %arg6] [128, 128] [1, 1] : tensor<128x128xf32> into tensor<8192x16384xf32> +// scf.yield %inserted_slice : tensor<8192x16384xf32> +// } {__parallel_loop__ = 0 : i64} +// scf.yield %4 : tensor<8192x16384xf32> +// } {__parallel_loop__ = 0 : i64} +// return %3 : tensor<8192x16384xf32> +// } + +// // CHECK-LABEL: func @fuse_mlp_test4 +// #map = affine_map<(d0) -> (d0 * 64)> +// #map1 = affine_map<(d0) -> (d0 * 128)> +// #map2 = affine_map<(d0) -> (d0 * 4)> +// #map3 = affine_map<(d0) -> (d0 floordiv 16)> +// #map4 = affine_map<(d0) -> (d0 floordiv 32)> +// func.func @fuse_mlp_test4(%arg0: tensor<128x512xbf16>, %arg1: tensor<32x8x16x32xbf16>, %arg2: tensor<256xbf16>) -> tensor<128x256xbf16> { +// %c32 = arith.constant 32 : index +// %c512 = arith.constant 512 : index +// %c128 = arith.constant 128 : index +// %c64 = arith.constant 64 : index +// %c0 = arith.constant 0 : index +// %cst = arith.constant 0.000000e+00 : bf16 +// %0 = tensor.empty() : tensor<128x256xbf16> +// %1 = tensor.empty() : tensor<512x256xbf16> +// %2:3 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %0, %arg6 = %0, %arg7 = %0) -> (tensor<128x256xbf16>, tensor<128x256xbf16>, tensor<128x256xbf16>) { +// %3 = affine.apply #map(%arg3) +// %4 = affine.apply #map1(%arg4) +// %extracted_slice = tensor.extract_slice %arg0[%3, 0] [64, 512] [1, 1] : tensor<128x512xbf16> to tensor<64x512xbf16> +// %5 = affine.apply #map2(%arg4) +// %extracted_slice_0 = tensor.extract_slice %arg1[0, %5, 0, 0] [32, 4, 16, 32] [1, 1, 1, 1] : tensor<32x8x16x32xbf16> to tensor<32x4x16x32xbf16> +// %extracted_slice_1 = tensor.extract_slice %1[0, %4] [512, 128] [1, 1] : tensor<512x256xbf16> to tensor<512x128xbf16> +// %extracted_slice_2 = tensor.extract_slice %arg5[%3, %4] [64, 128] [1, 1] : tensor<128x256xbf16> to tensor<64x128xbf16> +// %extracted_slice_3 = tensor.extract_slice %arg2[%4] [128] [1] : tensor<256xbf16> to tensor<128xbf16> +// %extracted_slice_4 = tensor.extract_slice %0[%3, %4] [64, 128] [1, 1] : tensor<128x256xbf16> to tensor<64x128xbf16> +// %extracted_slice_5 = tensor.extract_slice %arg6[%3, %4] [64, 128] [1, 1] : tensor<128x256xbf16> to tensor<64x128xbf16> +// %extracted_slice_6 = tensor.extract_slice %arg7[%3, %4] [64, 128] [1, 1] : tensor<128x256xbf16> to tensor<64x128xbf16> +// %6:3 = scf.for %arg8 = %c0 to %c64 step %c64 iter_args(%arg9 = %extracted_slice_2, %arg10 = %extracted_slice_5, %arg11 = %extracted_slice_6) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { +// %7:3 = scf.for %arg12 = %c0 to %c128 step %c128 iter_args(%arg13 = %arg9, %arg14 = %arg10, %arg15 = %arg11) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { +// %8:3 = scf.for %arg16 = %c0 to %c512 step %c512 iter_args(%arg17 = %arg13, %arg18 = %arg14, %arg19 = %arg15) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { +// %extracted_slice_7 = tensor.extract_slice %extracted_slice[%arg8, %arg16] [64, 512] [1, 1] : tensor<64x512xbf16> to tensor<64x512xbf16> +// %9 = affine.apply 
#map3(%arg16) +// %10 = affine.apply #map4(%arg12) +// %extracted_slice_8 = tensor.extract_slice %extracted_slice_0[%9, %10, 0, 0] [32, 4, 16, 32] [1, 1, 1, 1] : tensor<32x4x16x32xbf16> to tensor<32x4x16x32xbf16> +// %extracted_slice_9 = tensor.extract_slice %extracted_slice_1[%arg16, %arg12] [512, 128] [1, 1] : tensor<512x128xbf16> to tensor<512x128xbf16> +// %extracted_slice_10 = tensor.extract_slice %arg17[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> to tensor<64x128xbf16> +// %extracted_slice_11 = tensor.extract_slice %extracted_slice_3[%arg12] [128] [1] : tensor<128xbf16> to tensor<128xbf16> +// %extracted_slice_12 = tensor.extract_slice %extracted_slice_4[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> to tensor<64x128xbf16> +// %extracted_slice_13 = tensor.extract_slice %arg18[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> to tensor<64x128xbf16> +// %extracted_slice_14 = tensor.extract_slice %arg19[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> to tensor<64x128xbf16> +// %11:3 = scf.for %arg20 = %c0 to %c64 step %c32 iter_args(%arg21 = %extracted_slice_10, %arg22 = %extracted_slice_13, %arg23 = %extracted_slice_14) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { +// %12:3 = scf.for %arg24 = %c0 to %c128 step %c32 iter_args(%arg25 = %arg21, %arg26 = %arg22, %arg27 = %arg23) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { +// %13:3 = scf.for %arg28 = %c0 to %c512 step %c512 iter_args(%arg29 = %arg25, %arg30 = %arg26, %arg31 = %arg27) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { +// %extracted_slice_17 = tensor.extract_slice %extracted_slice_7[%arg20, %arg28] [32, 512] [1, 1] : tensor<64x512xbf16> to tensor<32x512xbf16> +// %14 = affine.apply #map3(%arg28) +// %15 = affine.apply #map4(%arg24) +// %extracted_slice_18 = tensor.extract_slice %extracted_slice_8[%14, %15, 0, 0] [32, 1, 16, 32] [1, 1, 1, 1] : tensor<32x4x16x32xbf16> to tensor<32x1x16x32xbf16> +// %extracted_slice_19 = tensor.extract_slice %extracted_slice_9[%arg28, %arg24] [512, 32] [1, 1] : tensor<512x128xbf16> to tensor<512x32xbf16> +// %unpack = tensor.unpack %extracted_slice_18 inner_dims_pos = [0, 1] inner_tiles = [16, 32] into %extracted_slice_19 : tensor<32x1x16x32xbf16> -> tensor<512x32xbf16> +// // CHECK: %[[C16:.*]] = arith.constant 16 : index +// // CHECK: %[[C1:.*]] = arith.constant 1 : index +// // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : bf16 +// // CHECK: %[[C0:.*]] = arith.constant 0 : index +// // CHECK: %[[C64:.*]] = arith.constant 64 : index +// // CHECK: %[[C128:.*]] = arith.constant 128 : index +// // CHECK: %[[C512:.*]] = arith.constant 512 : index +// // CHECK: %[[C32:.*]] = arith.constant 32 : index +// // CHECK: %[[EMPTY0:.*]] = tensor.empty() : tensor<128x256xbf16> +// // CHECK: scf.forall +// // CHECK-COUNT-6: scf.for +// // CHECK-COUNT-4: scf.for +// // CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<32x1x16x32xbf16>, vector<1xbf16> +// // CHECK: %[[WRITE0:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<1xbf16>, tensor<32x16x1x32xbf16> +// // CHECK: %[[FILL0:.*]] = linalg.fill +// // CHECK-COUNT-3: scf.for +// // CHECK: %[[APPLY0:.*]] = affine.apply +// // CHECK: %[[READ1:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<32x512xbf16>, vector<32xbf16> +// // CHECK: %[[WRITE1:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<32xbf16>, tensor<1x32x512xbf16> +// // CHECK-COUNT-4: 
scf.for +// // CHECK: %[[READ2:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<32x16x1x32xbf16>, vector<32xbf16> +// // CHECK: %[[APPLY1:.*]] = affine.apply +// // CHECK: %[[WRITE2:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<32xbf16>, tensor<1x512x32xbf16> +// // CHECK: %[[MATMUL0:.*]] = linalg.batch_reduce_matmul +// // CHECK-COUNT-2: scf.for +// // CHECK: %[[READ3:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<32x32xbf16>, vector<32xbf16> +// // CHECK: %[[READ4:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<32xbf16>, vector<32xbf16> +// // CHECK: %[[ADD0:.*]] = arith.addf %[[READ3]], %[[READ4]] : vector<32xbf16> +// // CHECK: %[[EXP0:.*]] = math.exp %[[ADD0]] : vector<32xbf16> +// // CHECK: %[[WRITE3:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<32xbf16>, tensor<32x32xbf16> +// // CHECK: %[[WRITE4:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<32xbf16>, tensor<32x32xbf16> +// %extracted_slice_20 = tensor.extract_slice %arg29[%arg20, %arg24] [32, 32] [1, 1] : tensor<64x128xbf16> to tensor<32x32xbf16> +// %16 = linalg.fill ins(%cst : bf16) outs(%extracted_slice_20 : tensor<32x32xbf16>) -> tensor<32x32xbf16> +// %expanded = tensor.expand_shape %extracted_slice_17 [[0, 1], [2]] output_shape [1, 32, 512] : tensor<32x512xbf16> into tensor<1x32x512xbf16> +// %expanded_21 = tensor.expand_shape %unpack [[0, 1], [2]] output_shape [1, 32, 512] : tensor<512x32xbf16> into tensor<1x512x32xbf16> +// %17 = linalg.batch_reduce_matmul ins(%expanded, %expanded_21 : tensor<1x32x512xbf16>, tensor<1x512x32xbf16>) outs(%16 : tensor<32x32xbf16>) -> tensor<32x32xbf16> +// %extracted_slice_22 = tensor.extract_slice %extracted_slice_11[%arg24] [32] [1] : tensor<128xbf16> to tensor<32xbf16> +// %extracted_slice_23 = tensor.extract_slice %extracted_slice_12[%arg20, %arg24] [32, 32] [1, 1] : tensor<64x128xbf16> to tensor<32x32xbf16> +// %broadcasted = linalg.broadcast ins(%extracted_slice_22 : tensor<32xbf16>) outs(%extracted_slice_23 : tensor<32x32xbf16>) dimensions = [0] +// %extracted_slice_24 = tensor.extract_slice %arg30[%arg20, %arg24] [32, 32] [1, 1] : tensor<64x128xbf16> to tensor<32x32xbf16> +// %18 = linalg.add ins(%17, %broadcasted : tensor<32x32xbf16>, tensor<32x32xbf16>) outs(%extracted_slice_24 : tensor<32x32xbf16>) -> tensor<32x32xbf16> +// %inserted_slice_25 = tensor.insert_slice %17 into %arg29[%arg20, %arg24] [32, 32] [1, 1] : tensor<32x32xbf16> into tensor<64x128xbf16> +// %extracted_slice_26 = tensor.extract_slice %arg31[%arg20, %arg24] [32, 32] [1, 1] : tensor<64x128xbf16> to tensor<32x32xbf16> +// %19 = linalg.exp ins(%18 : tensor<32x32xbf16>) outs(%extracted_slice_26 : tensor<32x32xbf16>) -> tensor<32x32xbf16> +// %inserted_slice_27 = tensor.insert_slice %18 into %arg30[%arg20, %arg24] [32, 32] [1, 1] : tensor<32x32xbf16> into tensor<64x128xbf16> +// %inserted_slice_28 = tensor.insert_slice %19 into %arg31[%arg20, %arg24] [32, 32] [1, 1] : tensor<32x32xbf16> into tensor<64x128xbf16> +// scf.yield %inserted_slice_25, %inserted_slice_27, %inserted_slice_28 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> +// } +// scf.yield %13#0, %13#1, %13#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> +// } +// scf.yield %12#0, %12#1, %12#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> +// } +// %inserted_slice = tensor.insert_slice %11#0 into %arg17[%arg8, %arg12] [64, 128] [1, 1] : 
tensor<64x128xbf16> into tensor<64x128xbf16> +// %inserted_slice_15 = tensor.insert_slice %11#1 into %arg18[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<64x128xbf16> +// %inserted_slice_16 = tensor.insert_slice %11#2 into %arg19[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<64x128xbf16> +// scf.yield %inserted_slice, %inserted_slice_15, %inserted_slice_16 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> +// } +// scf.yield %8#0, %8#1, %8#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> +// } +// scf.yield %7#0, %7#1, %7#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> +// } +// scf.forall.in_parallel { +// tensor.parallel_insert_slice %6#2 into %arg7[%3, %4] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<128x256xbf16> +// tensor.parallel_insert_slice %6#1 into %arg6[%3, %4] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<128x256xbf16> +// tensor.parallel_insert_slice %6#0 into %arg5[%3, %4] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<128x256xbf16> +// } +// } +// return %2#2 : tensor<128x256xbf16> +// } + +// // CHECK-LABEL: func @elem_pack_transpose_inner_dims_test5 +// // CHECK: %[[C32:.*]] = arith.constant 32 : index +// // CHECK: %[[C4:.*]] = arith.constant 4 : index +// // CHECK: %[[C256:.*]] = arith.constant 256 : index +// // CHECK: %[[C16:.*]] = arith.constant 16 : index +// // CHECK: %[[C128:.*]] = arith.constant 128 : index +// // CHECK: %[[C1:.*]] = arith.constant 1 : index +// // CHECK: %[[C0:.*]] = arith.constant 0 : index +// // CHECK: %[[C32I32:.*]] = arith.constant 0 : i32 +// // CHECK: %[[EMPTY0:.*]] = tensor.empty() : tensor<4x32x16x16xi32> +// // CHECK: %[[EMPTY1:.*]] = tensor.empty() : tensor<128x256xi32> +// // CHECK: %[[EMPTY2:.*]] = tensor.empty() : tensor<4x16x16x32xi32> +// // CHECK: scf.for +// // CHECK: scf.for +// // CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<128x256xi32>, vector<16xi32> +// // CHECK: %[[ADD0:.*]] = arith.addi %[[READ0]], %[[READ0]] : vector<16xi32> +// // CHECK: %[[WRITE0:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<16xi32>, tensor<128x256xi32> +// // CHECK: scf.for %[[arg2:.*]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[arg3:.*]] = %[[EMPTY0]]) -> (tensor<4x32x16x16xi32>) +// // CHECK: scf.for %[[arg4:.*]] = %[[C0]] to %[[C32]] step %[[C1]] iter_args(%[[arg5:.*]] = %[[arg3]]) -> (tensor<4x32x16x16xi32>) +// // CHECK: scf.for %[[arg6:.*]] = %[[C0]] to %[[C16]] step %[[C1]] iter_args(%[[arg7:.*]] = %[[arg5]]) -> (tensor<4x32x16x16xi32>) +// // CHECK: scf.for %[[arg8:.*]] = %[[C0]] to %[[C16]] step %[[C16]] iter_args(%[[arg9:.*]] = %[[arg7]]) -> (tensor<4x32x16x16xi32>) +// // CHECK: %[[APPLY0:.*]] = affine.apply #[[map7]]()[%[[arg2]], %[[arg4]]] +// // CHECK: %[[APPLY1:.*]] = affine.apply #[[map8]]()[%[[arg6]], %[[arg8]]] +// // CHECK: %[[READ1:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<128x256xi32>, vector<16xi32> +// // CHECK: %[[WRITE1:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<16xi32>, tensor<4x32x16x16xi32> +// // CHECK-COUNT-4: scf.for +// // CHECK: %[[READ2:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<4x32x16x16xi32>, vector<1xi32> +// // CHECK: %[[WRITE2:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<1xi32>, tensor<4x16x16x32xi32> +// #map5 = affine_map<(d0, d1) -> (d0, d1)> +// func.func @elem_pack_transpose_inner_dims_test5(%arg0: 
tensor<128x256xi32>, %dest: tensor<4x16x16x32xi32>) -> tensor<4x16x16x32xi32>{ +// %init = tensor.empty() : tensor<128x256xi32> +// %elem = linalg.generic {indexing_maps = [#map5, #map5], iterator_types = ["parallel", "parallel"]} +// ins(%arg0 : tensor<128x256xi32>) +// outs(%init : tensor<128x256xi32>) { +// ^bb0(%arg3: i32, %arg4: i32): +// %4 = arith.addi %arg3, %arg3 : i32 +// linalg.yield %4 : i32 +// } -> tensor<128x256xi32> +// %pack = tensor.pack %elem +// inner_dims_pos = [1, 0] +// inner_tiles = [16, 32] +// into %dest : tensor<128x256xi32> -> tensor<4x16x16x32xi32> +// return %pack : tensor<4x16x16x32xi32> +// } -// CHECK-LABEL: func @fuse_mlp_test4 -#map = affine_map<(d0) -> (d0 * 64)> -#map1 = affine_map<(d0) -> (d0 * 128)> -#map2 = affine_map<(d0) -> (d0 * 4)> -#map3 = affine_map<(d0) -> (d0 floordiv 16)> -#map4 = affine_map<(d0) -> (d0 floordiv 32)> -func.func @fuse_mlp_test4(%arg0: tensor<128x512xbf16>, %arg1: tensor<32x8x16x32xbf16>, %arg2: tensor<256xbf16>) -> tensor<128x256xbf16> { - %c32 = arith.constant 32 : index - %c512 = arith.constant 512 : index - %c128 = arith.constant 128 : index - %c64 = arith.constant 64 : index - %c0 = arith.constant 0 : index - %cst = arith.constant 0.000000e+00 : bf16 - %0 = tensor.empty() : tensor<128x256xbf16> - %1 = tensor.empty() : tensor<512x256xbf16> - %2:3 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %0, %arg6 = %0, %arg7 = %0) -> (tensor<128x256xbf16>, tensor<128x256xbf16>, tensor<128x256xbf16>) { - %3 = affine.apply #map(%arg3) - %4 = affine.apply #map1(%arg4) - %extracted_slice = tensor.extract_slice %arg0[%3, 0] [64, 512] [1, 1] : tensor<128x512xbf16> to tensor<64x512xbf16> - %5 = affine.apply #map2(%arg4) - %extracted_slice_0 = tensor.extract_slice %arg1[0, %5, 0, 0] [32, 4, 16, 32] [1, 1, 1, 1] : tensor<32x8x16x32xbf16> to tensor<32x4x16x32xbf16> - %extracted_slice_1 = tensor.extract_slice %1[0, %4] [512, 128] [1, 1] : tensor<512x256xbf16> to tensor<512x128xbf16> - %extracted_slice_2 = tensor.extract_slice %arg5[%3, %4] [64, 128] [1, 1] : tensor<128x256xbf16> to tensor<64x128xbf16> - %extracted_slice_3 = tensor.extract_slice %arg2[%4] [128] [1] : tensor<256xbf16> to tensor<128xbf16> - %extracted_slice_4 = tensor.extract_slice %0[%3, %4] [64, 128] [1, 1] : tensor<128x256xbf16> to tensor<64x128xbf16> - %extracted_slice_5 = tensor.extract_slice %arg6[%3, %4] [64, 128] [1, 1] : tensor<128x256xbf16> to tensor<64x128xbf16> - %extracted_slice_6 = tensor.extract_slice %arg7[%3, %4] [64, 128] [1, 1] : tensor<128x256xbf16> to tensor<64x128xbf16> - %6:3 = scf.for %arg8 = %c0 to %c64 step %c64 iter_args(%arg9 = %extracted_slice_2, %arg10 = %extracted_slice_5, %arg11 = %extracted_slice_6) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { - %7:3 = scf.for %arg12 = %c0 to %c128 step %c128 iter_args(%arg13 = %arg9, %arg14 = %arg10, %arg15 = %arg11) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { - %8:3 = scf.for %arg16 = %c0 to %c512 step %c512 iter_args(%arg17 = %arg13, %arg18 = %arg14, %arg19 = %arg15) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { - %extracted_slice_7 = tensor.extract_slice %extracted_slice[%arg8, %arg16] [64, 512] [1, 1] : tensor<64x512xbf16> to tensor<64x512xbf16> - %9 = affine.apply #map3(%arg16) - %10 = affine.apply #map4(%arg12) - %extracted_slice_8 = tensor.extract_slice %extracted_slice_0[%9, %10, 0, 0] [32, 4, 16, 32] [1, 1, 1, 1] : tensor<32x4x16x32xbf16> to tensor<32x4x16x32xbf16> - %extracted_slice_9 = tensor.extract_slice 
%extracted_slice_1[%arg16, %arg12] [512, 128] [1, 1] : tensor<512x128xbf16> to tensor<512x128xbf16> - %extracted_slice_10 = tensor.extract_slice %arg17[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> to tensor<64x128xbf16> - %extracted_slice_11 = tensor.extract_slice %extracted_slice_3[%arg12] [128] [1] : tensor<128xbf16> to tensor<128xbf16> - %extracted_slice_12 = tensor.extract_slice %extracted_slice_4[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> to tensor<64x128xbf16> - %extracted_slice_13 = tensor.extract_slice %arg18[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> to tensor<64x128xbf16> - %extracted_slice_14 = tensor.extract_slice %arg19[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> to tensor<64x128xbf16> - %11:3 = scf.for %arg20 = %c0 to %c64 step %c32 iter_args(%arg21 = %extracted_slice_10, %arg22 = %extracted_slice_13, %arg23 = %extracted_slice_14) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { - %12:3 = scf.for %arg24 = %c0 to %c128 step %c32 iter_args(%arg25 = %arg21, %arg26 = %arg22, %arg27 = %arg23) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { - %13:3 = scf.for %arg28 = %c0 to %c512 step %c512 iter_args(%arg29 = %arg25, %arg30 = %arg26, %arg31 = %arg27) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { - %extracted_slice_17 = tensor.extract_slice %extracted_slice_7[%arg20, %arg28] [32, 512] [1, 1] : tensor<64x512xbf16> to tensor<32x512xbf16> - %14 = affine.apply #map3(%arg28) - %15 = affine.apply #map4(%arg24) - %extracted_slice_18 = tensor.extract_slice %extracted_slice_8[%14, %15, 0, 0] [32, 1, 16, 32] [1, 1, 1, 1] : tensor<32x4x16x32xbf16> to tensor<32x1x16x32xbf16> - %extracted_slice_19 = tensor.extract_slice %extracted_slice_9[%arg28, %arg24] [512, 32] [1, 1] : tensor<512x128xbf16> to tensor<512x32xbf16> - %unpack = tensor.unpack %extracted_slice_18 inner_dims_pos = [0, 1] inner_tiles = [16, 32] into %extracted_slice_19 : tensor<32x1x16x32xbf16> -> tensor<512x32xbf16> +// #map6 = affine_map<(d0, d1) -> (d0, d1)> +// func.func @elem_pack_transpose_outer_dims_test6(%arg0: tensor<128x256xi32>, %dest: tensor<16x4x32x16xi32>) -> tensor<16x4x32x16xi32>{ +// %init = tensor.empty() : tensor<128x256xi32> +// %elem = linalg.generic {indexing_maps = [#map6, #map6], iterator_types = ["parallel", "parallel"]} +// ins(%arg0 : tensor<128x256xi32>) +// outs(%init : tensor<128x256xi32>) { +// ^bb0(%arg3: i32, %arg4: i32): +// %4 = arith.addi %arg3, %arg3 : i32 +// linalg.yield %4 : i32 +// } -> tensor<128x256xi32> +// %pack = tensor.pack %elem +// outer_dims_perm = [1, 0] +// inner_dims_pos = [0, 1] +// inner_tiles = [32, 16] +// into %dest : tensor<128x256xi32> -> tensor<16x4x32x16xi32> +// return %pack : tensor<16x4x32x16xi32> +// } + + +// #map7 = affine_map<(d0, d1) -> (d0, d1)> +// func.func @elem_pack_transpose_inner_and_outer_dims_test7(%arg0: tensor<128x256xi32>, %dest: tensor<16x4x16x32xi32>) -> tensor<16x4x16x32xi32>{ +// %init = tensor.empty() : tensor<128x256xi32> +// %elem = linalg.generic {indexing_maps = [#map7, #map7], iterator_types = ["parallel", "parallel"]} +// ins(%arg0 : tensor<128x256xi32>) +// outs(%init : tensor<128x256xi32>) { +// ^bb0(%arg3: i32, %arg4: i32): +// %4 = arith.addi %arg3, %arg3 : i32 +// linalg.yield %4 : i32 +// } -> tensor<128x256xi32> +// %pack = tensor.pack %elem +// outer_dims_perm = [1, 0] +// inner_dims_pos = [1, 0] +// inner_tiles = [16, 32] +// into %dest : tensor<128x256xi32> -> tensor<16x4x16x32xi32> +// return %pack : tensor<16x4x16x32xi32> +// 
} + +// CHECK-LABEL: func @elem_pack_transpose_inner_and_outer_dims2_test8 +// CHECK: %[[C32:.*]] = arith.constant 32 : index +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[C64:.*]] = arith.constant 64 : index // CHECK: %[[C16:.*]] = arith.constant 16 : index +// CHECK: %[[C57:.*]] = arith.constant 57 : index +// CHECK: %[[C56:.*]] = arith.constant 56 : index // CHECK: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : bf16 // CHECK: %[[C0:.*]] = arith.constant 0 : index -// CHECK: %[[C64:.*]] = arith.constant 64 : index -// CHECK: %[[C128:.*]] = arith.constant 128 : index -// CHECK: %[[C512:.*]] = arith.constant 512 : index -// CHECK: %[[C32:.*]] = arith.constant 32 : index -// CHECK: %[[EMPTY0:.*]] = tensor.empty() : tensor<128x256xbf16> -// CHECK: scf.forall -// CHECK-COUNT-6: scf.for -// CHECK-COUNT-4: scf.for -// CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<32x1x16x32xbf16>, vector<1xbf16> -// CHECK: %[[WRITE0:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<1xbf16>, tensor<32x16x1x32xbf16> -// CHECK: %[[FILL0:.*]] = linalg.fill -// CHECK-COUNT-3: scf.for -// CHECK: %[[APPLY0:.*]] = affine.apply -// CHECK: %[[READ1:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<32x512xbf16>, vector<32xbf16> -// CHECK: %[[WRITE1:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<32xbf16>, tensor<1x32x512xbf16> +// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[EMPTY0:.*]] = tensor.empty() : tensor<1x56x57x2x32xf32> +// CHECK: %[[EMPTY1:.*]] = tensor.empty() : tensor<1x56x57x64xf32> +// CHECK: %[[EMPTY2:.*]] = tensor.empty() : tensor<1x2x56x57x32xf32> +// CHECK: scf.for +// CHECK: +// CHECK: scf.for +// CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<128x256xi32>, vector<16xi32> +// CHECK: %[[ADD0:.*]] = arith.addi %[[READ0]], %[[READ0]] : vector<16xi32> +// CHECK: %[[WRITE0:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<16xi32>, tensor<128x256xi32> +// CHECK: scf.for %[[arg2:.*]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[arg3:.*]] = %[[EMPTY0]]) -> (tensor<4x32x16x16xi32>) +// CHECK: scf.for %[[arg4:.*]] = %[[C0]] to %[[C32]] step %[[C1]] iter_args(%[[arg5:.*]] = %[[arg3]]) -> (tensor<4x32x16x16xi32>) +// CHECK: scf.for %[[arg6:.*]] = %[[C0]] to %[[C16]] step %[[C1]] iter_args(%[[arg7:.*]] = %[[arg5]]) -> (tensor<4x32x16x16xi32>) +// CHECK: scf.for %[[arg8:.*]] = %[[C0]] to %[[C16]] step %[[C16]] iter_args(%[[arg9:.*]] = %[[arg7]]) -> (tensor<4x32x16x16xi32>) +// CHECK: %[[APPLY0:.*]] = affine.apply #[[map7]]()[%[[arg2]], %[[arg4]]] +// CHECK: %[[APPLY1:.*]] = affine.apply #[[map8]]()[%[[arg6]], %[[arg8]]] +// CHECK: %[[READ1:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<128x256xi32>, vector<16xi32> +// CHECK: %[[WRITE1:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<16xi32>, tensor<4x32x16x16xi32> // CHECK-COUNT-4: scf.for -// CHECK: %[[READ2:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<32x16x1x32xbf16>, vector<32xbf16> -// CHECK: %[[APPLY1:.*]] = affine.apply -// CHECK: %[[WRITE2:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<32xbf16>, tensor<1x512x32xbf16> -// CHECK: %[[MATMUL0:.*]] = linalg.batch_reduce_matmul -// CHECK-COUNT-2: scf.for -// CHECK: %[[READ3:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : 
tensor<32x32xbf16>, vector<32xbf16> -// CHECK: %[[READ4:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<32xbf16>, vector<32xbf16> -// CHECK: %[[ADD0:.*]] = arith.addf %[[READ3]], %[[READ4]] : vector<32xbf16> -// CHECK: %[[EXP0:.*]] = math.exp %[[ADD0]] : vector<32xbf16> -// CHECK: %[[WRITE3:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<32xbf16>, tensor<32x32xbf16> -// CHECK: %[[WRITE4:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<32xbf16>, tensor<32x32xbf16> - %extracted_slice_20 = tensor.extract_slice %arg29[%arg20, %arg24] [32, 32] [1, 1] : tensor<64x128xbf16> to tensor<32x32xbf16> - %16 = linalg.fill ins(%cst : bf16) outs(%extracted_slice_20 : tensor<32x32xbf16>) -> tensor<32x32xbf16> - %expanded = tensor.expand_shape %extracted_slice_17 [[0, 1], [2]] output_shape [1, 32, 512] : tensor<32x512xbf16> into tensor<1x32x512xbf16> - %expanded_21 = tensor.expand_shape %unpack [[0, 1], [2]] output_shape [1, 32, 512] : tensor<512x32xbf16> into tensor<1x512x32xbf16> - %17 = linalg.batch_reduce_matmul ins(%expanded, %expanded_21 : tensor<1x32x512xbf16>, tensor<1x512x32xbf16>) outs(%16 : tensor<32x32xbf16>) -> tensor<32x32xbf16> - %extracted_slice_22 = tensor.extract_slice %extracted_slice_11[%arg24] [32] [1] : tensor<128xbf16> to tensor<32xbf16> - %extracted_slice_23 = tensor.extract_slice %extracted_slice_12[%arg20, %arg24] [32, 32] [1, 1] : tensor<64x128xbf16> to tensor<32x32xbf16> - %broadcasted = linalg.broadcast ins(%extracted_slice_22 : tensor<32xbf16>) outs(%extracted_slice_23 : tensor<32x32xbf16>) dimensions = [0] - %extracted_slice_24 = tensor.extract_slice %arg30[%arg20, %arg24] [32, 32] [1, 1] : tensor<64x128xbf16> to tensor<32x32xbf16> - %18 = linalg.add ins(%17, %broadcasted : tensor<32x32xbf16>, tensor<32x32xbf16>) outs(%extracted_slice_24 : tensor<32x32xbf16>) -> tensor<32x32xbf16> - %inserted_slice_25 = tensor.insert_slice %17 into %arg29[%arg20, %arg24] [32, 32] [1, 1] : tensor<32x32xbf16> into tensor<64x128xbf16> - %extracted_slice_26 = tensor.extract_slice %arg31[%arg20, %arg24] [32, 32] [1, 1] : tensor<64x128xbf16> to tensor<32x32xbf16> - %19 = linalg.exp ins(%18 : tensor<32x32xbf16>) outs(%extracted_slice_26 : tensor<32x32xbf16>) -> tensor<32x32xbf16> - %inserted_slice_27 = tensor.insert_slice %18 into %arg30[%arg20, %arg24] [32, 32] [1, 1] : tensor<32x32xbf16> into tensor<64x128xbf16> - %inserted_slice_28 = tensor.insert_slice %19 into %arg31[%arg20, %arg24] [32, 32] [1, 1] : tensor<32x32xbf16> into tensor<64x128xbf16> - scf.yield %inserted_slice_25, %inserted_slice_27, %inserted_slice_28 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> - } - scf.yield %13#0, %13#1, %13#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> - } - scf.yield %12#0, %12#1, %12#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> - } - %inserted_slice = tensor.insert_slice %11#0 into %arg17[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<64x128xbf16> - %inserted_slice_15 = tensor.insert_slice %11#1 into %arg18[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<64x128xbf16> - %inserted_slice_16 = tensor.insert_slice %11#2 into %arg19[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<64x128xbf16> - scf.yield %inserted_slice, %inserted_slice_15, %inserted_slice_16 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> - } - scf.yield %8#0, %8#1, %8#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> - } - 
scf.yield %7#0, %7#1, %7#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> - } - scf.forall.in_parallel { - tensor.parallel_insert_slice %6#2 into %arg7[%3, %4] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<128x256xbf16> - tensor.parallel_insert_slice %6#1 into %arg6[%3, %4] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<128x256xbf16> - tensor.parallel_insert_slice %6#0 into %arg5[%3, %4] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<128x256xbf16> - } - } - return %2#2 : tensor<128x256xbf16> - } +// CHECK: %[[READ2:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<4x32x16x16xi32>, vector<1xi32> +// CHECK: %[[WRITE2:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<1xi32>, tensor<4x16x16x32xi32> +#map8 = affine_map<(d0, d1, d2, d3) -> (d3)> +#map9 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +func.func @elem_pack_transpose_inner_and_outer_dims2_test8(%arg0: tensor<64xf32>, %dest: tensor<1x2x56x57x32xf32>) -> tensor<1x2x56x57x32xf32> { + %0 = tensor.empty() : tensor<1x56x57x64xf32> + %1 = linalg.generic { + indexing_maps = [#map8, #map9], + iterator_types = ["parallel", "parallel", "parallel", "parallel"]} + ins(%arg0 : tensor<64xf32>) + outs(%0 : tensor<1x56x57x64xf32>) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } -> tensor<1x56x57x64xf32> + %2 = tensor.pack %1 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %dest : tensor<1x56x57x64xf32> -> tensor<1x2x56x57x32xf32> + return %2 : tensor<1x2x56x57x32xf32> +} + + +// // CHECK-LABEL: func @broadcast_same_shape_test9 +// // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// // CHECK: %[[C0:.*]] = arith.constant 0 : index +// // CHECK: %[[C1:.*]] = arith.constant 1 : index +// // CHECK: %[[C2:.*]] = arith.constant 2 : index +// // CHECK: %[[C16:.*]] = arith.constant 16 : index +// // CHECK: scf.for +// // CHECK: scf.for +// // CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<2x16xf32>, vector<16xf32> +// // CHECK: %[[READ1:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<16xf32>, vector<16xf32> +// // CHECK: %[[ADD0:.*]] = arith.addf %[[READ1]], %[[READ0]] : vector<16xf32> +// // CHECK: %[[WRITE0:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<16xf32>, tensor<2x16xf32> +// func.func @broadcast_same_shape_test9(%input: tensor<16xf32>, %init: tensor<2x16xf32>) -> tensor<2x16xf32> { +// %empty = tensor.empty() : tensor<2x16xf32> +// %0 = linalg.broadcast ins(%input: tensor<16xf32>) outs(%empty: tensor<2x16xf32>) dimensions = [0] +// %1 = linalg.add ins(%0, %init : tensor<2x16xf32>, tensor<2x16xf32>) outs(%init : tensor<2x16xf32>) -> tensor<2x16xf32> +// return %1 : tensor<2x16xf32> +// } From 284a97b94104c12cd33ed88c233364899838e66c Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Tue, 3 Sep 2024 11:04:44 +0800 Subject: [PATCH 35/66] add test --- lib/gc/Transforms/CPUPhysicalRegisterPass.cpp | 79 +- lib/gc/Transforms/TilingVector.h | 22 +- .../gc/Transforms/cpu-vetor-distribution.mlir | 886 +++++++++++------- 3 files changed, 566 insertions(+), 421 deletions(-) diff --git a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp index 25dc2a436..e12e6fea6 100644 --- a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp +++ b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp @@ -32,6 +32,7 @@ namespace { linalg::QuantizedBatchMatmulOp, linalg::QuantizedMatmulOp, \ tensor::CollapseShapeOp, 
tensor::ExpandShapeOp, tensor::ExtractSliceOp, \ tensor::InsertSliceOp +// , microkernel::BrgemmOp /// TODO: remove it in the future bool disableSpecialOp = false; @@ -1808,19 +1809,18 @@ scf::ForOp ForLoopGenerator::generateTransposeForLoopWithLastDim( }); } -ValueRange ForLoopGenerator::prepareForLoopArgs( +void ForLoopGenerator::prepareForLoopArgs( const size_t grpIdx, DenseMap &currentLoopStateIdxMap, DenseMap &originalOperandLoopArgsMap, - DenseMap &loopArgsOriginalOperandMap) { + DenseMap &loopArgsOriginalOperandMap, + SmallVector &loopArgs) { SetVector &grpArgs = getGroupOpInitArgs()[grpIdx]; - SmallVector forLoopArgs(grpArgs.begin(), grpArgs.end()); - ValueRange initArgs(forLoopArgs); - for (auto [idx, val] : llvm::enumerate(initArgs)) { + loopArgs.assign(grpArgs.begin(), grpArgs.end()); + for (auto [idx, val] : llvm::enumerate(grpArgs)) { currentLoopStateIdxMap[val] = idx; originalOperandLoopArgsMap[val] = val; loopArgsOriginalOperandMap[val] = val; } - return initArgs; } void ForLoopGenerator::rearrageMultiReductionIR( @@ -1920,9 +1920,9 @@ ForLoopGenerator::generateMultiReductionForLoop(const size_t grpIdx) { DenseMap originalOperandLoopArgsMap, loopArgsOriginalOperandMap, forResultOrignalResultMap; - ValueRange initArgs = prepareForLoopArgs(grpIdx, currentLoopStateIdxMap, - originalOperandLoopArgsMap, - loopArgsOriginalOperandMap); + SmallVector initArgs; + prepareForLoopArgs(grpIdx, currentLoopStateIdxMap, originalOperandLoopArgsMap, + loopArgsOriginalOperandMap, initArgs); SmallVector inductionVars; OpBuilder opBuilder(rdCanonicalizer.getCandidateOps()[0]); @@ -2294,8 +2294,9 @@ scf::ForOp ForLoopGenerator::generateTransposeForLoop(const size_t grpIdx) { DenseMap operandIdxMap; DenseMap originalOperandMap, operandOriginalMap, resultIdxMap, forResultOrignalResultMap; - SmallVector iterArgs = prepareForLoopArgs( - grpIdx, operandIdxMap, originalOperandMap, operandOriginalMap); + SmallVector iterArgs; + prepareForLoopArgs(grpIdx, operandIdxMap, originalOperandMap, + operandOriginalMap, iterArgs); SmallVector inductionVars; // TODO: need to process transpose on all one dim @@ -2888,28 +2889,6 @@ scf::ForOp ForLoopGenerator::constructNestedForOp( return forOp; } -bool isSameVectorType(Operation *op1, Operation *op2) { - auto type1 = getOperationVectorType(op1); - auto type2 = getOperationVectorType(op2); - if (failed(type1) || failed(type2)) { - return false; - } - auto sp1 = type1.value(); - auto sp2 = type2.value(); - if (sp1.getRank() != sp2.getRank()) { - return false; - } - bool isSame = true; - // from front to back - for (long i = 0; i < sp1.getRank(); i++) { - if (sp1.getDimSize(i) != sp2.getDimSize(i)) { - isSame = false; - break; - } - } - return isSame; -} - /// default op1 is previous operation bool VectorFusionStrategy::isCompatibleVectorType(Operation *op1, Operation *op2) { @@ -3075,8 +3054,7 @@ void getOperationDataAxis(Operation *op, SmallVector &dataAxis) { .Default([&](Operation *op) { // default is last axis dataAxis.emplace_back( - mlir::dyn_cast(op->getResultTypes().front()).getRank() - - 1); + cast(op->getResultTypes()[0]).getRank() - 1); }); } @@ -3160,17 +3138,12 @@ bool hasDataDependency(Operation *op1, Operation *op2) { }) .Case( [&](vector::MultiDimReductionOp multiReductionOp) { - // has two cases: op1 is special operation, op2 is normal - // operation op1 and op2 is both speicial operation SmallVector dims2, reductionDims, parallelDims; getOperationDataAxis(op1, reductionDims); getOperationDataAxis(op2, dims2); DenseSet checkSet(dims2.begin(), dims2.end());
auto op2VectorType = getOperationVectorType(op2); if (!isSpecialOp(op2)) { - if (isSameVectorType(op1, op2)) { - return false; - } // all reduction axes should be op2's data axes bool reduceDependent = false; for (auto x : reductionDims) { @@ -3187,11 +3160,10 @@ bool hasDataDependency(Operation *op1, Operation *op2) { checkSet.insert(reductionDims.begin(), reductionDims.end()); auto rdRank = multiReductionOp.getSourceVectorType().getRank(); - for (auto i = 0; i < rdRank; i++) { - if (!checkSet.contains(i)) { + for (auto i = 0; i < rdRank; i++) + if (not checkSet.contains(i)) parallelDims.emplace_back(i); - } - } + checkSet.clear(); checkSet.insert(parallelDims.begin(), parallelDims.end()); auto rank = op2VectorType->getRank(); @@ -3244,9 +3216,12 @@ Operation *getNotReadWriteOperaiton(std::queue &tmpQ) { } bool VectorFusionStrategy::isNeedNewGroup(Operation *op) { + if (isa(op)) { + noNeedToJudgeOps.push(op); + return false; + } // 1. check previous operation if (!opGroups.back().empty()) { - // We only care about the calculation operation. std::queue tmpQ(opGroups.back()); Operation *prevOp = nullptr; @@ -3298,9 +3273,17 @@ void VectorFusionStrategy::addOperationToGroup(Operation *op) { if (isNeedNewGroup(op)) { opGroups.emplace_back(std::queue()); } - updateGroupBitgestVectorType(vectorType); - opGroups.back().push(op); - opGroupIndexMap[op] = opGroups.size() - 1; + if (not isa(op)) { + updateGroupBitgestVectorType(vectorType); + while (not noNeedToJudgeOps.empty()) { + auto cur = noNeedToJudgeOps.front(); + noNeedToJudgeOps.pop(); + opGroupIndexMap[cur] = opGroups.size() - 1; + opGroups.back().push(cur); + } + opGroups.back().push(op); + opGroupIndexMap[op] = opGroups.size() - 1; + } opAnchorPos[op] = getOperationMaxVectorType(op)->getRank() - 1; } diff --git a/lib/gc/Transforms/TilingVector.h b/lib/gc/Transforms/TilingVector.h index 1eee9b112..5fcc6b34e 100644 --- a/lib/gc/Transforms/TilingVector.h +++ b/lib/gc/Transforms/TilingVector.h @@ -49,6 +49,7 @@ #include #include #include +// #include "gc/Dialect/Microkernel/MicrokernelOps.h" namespace mlir { namespace gc { namespace { @@ -109,14 +110,17 @@ class TypeHelper { class VectorFusionStrategy : public TypeHelper { private: func::FuncOp func; - llvm::SmallVector, 8> opGroups; - llvm::SmallVector groupMaxSteps; + SmallVector, 8> opGroups; + SmallVector groupMaxSteps; /// vector type which has the biggest rank in the current operation group llvm::SmallDenseMap groupBigestRankVectorType; /// query which group the current operation is in, return the group index - llvm::DenseMap opGroupIndexMap; + DenseMap opGroupIndexMap; /// the axis position at which an operation can be fused into the previous operation - llvm::DenseMap opAnchorPos; + DenseMap opAnchorPos; + /// record operations that do not need to be judged on whether they can be + /// fused + std::queue noNeedToJudgeOps; public: VectorFusionStrategy() = default; @@ -468,11 +472,11 @@ class ForLoopGenerator : virtual public CanonicalizerCommonUsedData { void generateGroupOpVectorizedIR(const int idx); /// prepare for loop iteration args - ValueRange - prepareForLoopArgs(const size_t grpIdx, - DenseMap &currentLoopStateIdxMap, - DenseMap &originalOperandLoopArgsMap, - DenseMap &loopArgsOriginalOperandMap); + void prepareForLoopArgs(const size_t grpIdx, + DenseMap &currentLoopStateIdxMap, + DenseMap &originalOperandLoopArgsMap, + DenseMap &loopArgsOriginalOperandMap, + SmallVector &loopArgs); /// replace original operation result with corresponding for loop result void replaceOpUsersWithForLoopResult( diff --git
a/test/mlir/test/gc/Transforms/cpu-vetor-distribution.mlir b/test/mlir/test/gc/Transforms/cpu-vetor-distribution.mlir index 2b96d2fec..4f338951c 100644 --- a/test/mlir/test/gc/Transforms/cpu-vetor-distribution.mlir +++ b/test/mlir/test/gc/Transforms/cpu-vetor-distribution.mlir @@ -28,366 +28,281 @@ // CHECK: %[[ADD0:.*]] = arith.addf %[[READ1]], %[[READ0]] : vector<16xf32> // CHECK: %[[ADD1:.*]] = arith.addf %[[ADD0]], %[[READ0]] : vector<16xf32> // CHECK: %[[WRITE:.*]] = vector.transfer_write {{.*}}, {{.*}} : vector<16xf32>, tensor<11008x4096xf32> -// func.func @add_tensor_test0(%arg0: tensor<11008x4096xf32>, %arg1: tensor<11008x4096xf32>) -> tensor<11008x4096xf32> { - -// %0 = tensor.empty() : tensor<11008x4096xf32> -// %1 = linalg.add ins(%arg0, %arg1 : tensor<11008x4096xf32>, tensor<11008x4096xf32>) outs(%0: tensor<11008x4096xf32>) -> tensor<11008x4096xf32> -// %2 = linalg.add ins(%1, %arg1 : tensor<11008x4096xf32>, tensor<11008x4096xf32>) outs(%0: tensor<11008x4096xf32>) -> tensor<11008x4096xf32> -// return %2 : tensor<11008x4096xf32> -// } +func.func @add_tensor_test0(%arg0: tensor<11008x4096xf32>, %arg1: tensor<11008x4096xf32>) -> tensor<11008x4096xf32> { + %0 = tensor.empty() : tensor<11008x4096xf32> + %1 = linalg.add ins(%arg0, %arg1 : tensor<11008x4096xf32>, tensor<11008x4096xf32>) outs(%0: tensor<11008x4096xf32>) -> tensor<11008x4096xf32> + %2 = linalg.add ins(%1, %arg1 : tensor<11008x4096xf32>, tensor<11008x4096xf32>) outs(%0: tensor<11008x4096xf32>) -> tensor<11008x4096xf32> + return %2 : tensor<11008x4096xf32> +} -// // CHECK-LABEL: func @reduce_keepdimtest1 -// // CHECK: %[[C32:.*]] = arith.constant 32 : index -// // CHECK: %[[C64:.*]] = arith.constant 64 : index -// // CHECK: %[[C16:.*]] = arith.constant 16 : index -// // CHECK: %[[C1:.*]] = arith.constant 1 : index -// // CHECK: %[[C0:.*]] = arith.constant 0 : index -// // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 -// // CHECK: %[[EMPTY0:.*]] = tensor.empty() : tensor<16x64xf32> -// // CHECK: %[[EMPTY1:.*]] = tensor.empty() : tensor<16x1x64xf32> -// // CHECK: scf.for -// // CHECK: scf.for -// // CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<16x64xf32>, vector<16xf32> -// // CHECK: scf.for -// // CHECK: %[[READ1:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<16x32x64xf32>, vector<16xf32> -// // CHECK: %[[ADD0:.*]] = arith.addf %[[READ1]], {{.*}} : vector<16xf32> -// // CHECK: scf.yield -// // CHECK: %[[WRITE:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<16xf32>, tensor<16x64xf32> -// // CHECK: scf.yield -// // CHECK: scf.for -// // CHECK: scf.for -// // CHECK: scf.for -// // CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<16x64xf32>, vector<16xf32> -// func.func @reduce_keepdimtest1(%arg0: tensor<16x32x64xf32>) -> tensor<16x1x64xf32> { -// %0 = tensor.empty() : tensor<16x64xf32> -// %reduce = linalg.reduce -// ins(%arg0:tensor<16x32x64xf32>) -// outs(%0:tensor<16x64xf32>) -// dimensions = [1] -// (%in: f32, %out: f32) { -// %1 = arith.addf %out, %in: f32 -// linalg.yield %1: f32 -// } -// %2 = tensor.expand_shape %reduce [[0],[1, 2]] output_shape [16, 1, 64] : tensor<16x64xf32> into tensor<16x1x64xf32> -// return %2 : tensor<16x1x64xf32> -// } +// CHECK-LABEL: func @reduce_keepdimtest1 +// CHECK: %[[C32:.*]] = arith.constant 32 : index +// CHECK: %[[C64:.*]] = arith.constant 64 : index +// CHECK: %[[C16:.*]] = arith.constant 16 : index +// CHECK: %[[C1:.*]] = arith.constant 1 
: index +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[EMPTY0:.*]] = tensor.empty() : tensor<16x64xf32> +// CHECK: %[[EMPTY1:.*]] = tensor.empty() : tensor<16x1x64xf32> +// CHECK: scf.for +// CHECK: scf.for +// CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<16x64xf32>, vector<16xf32> +// CHECK: scf.for +// CHECK: %[[READ1:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<16x32x64xf32>, vector<16xf32> +// CHECK: %[[ADD0:.*]] = arith.addf %[[READ1]], {{.*}} : vector<16xf32> +// CHECK: scf.yield +// CHECK: %[[WRITE:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<16xf32>, tensor<16x64xf32> +// CHECK: scf.yield +// CHECK: scf.for +// CHECK: scf.for +// CHECK: scf.for +// CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<16x64xf32>, vector<16xf32> +func.func @reduce_keepdimtest1(%arg0: tensor<16x32x64xf32>) -> tensor<16x1x64xf32> { + %0 = tensor.empty() : tensor<16x64xf32> + %reduce = linalg.reduce + ins(%arg0:tensor<16x32x64xf32>) + outs(%0:tensor<16x64xf32>) + dimensions = [1] + (%in: f32, %out: f32) { + %1 = arith.addf %out, %in: f32 + linalg.yield %1: f32 + } + %2 = tensor.expand_shape %reduce [[0],[1, 2]] output_shape [16, 1, 64] : tensor<16x64xf32> into tensor<16x1x64xf32> + return %2 : tensor<16x1x64xf32> +} -// // CHECK-LABEL: func @fc_relu_test2 -// // CHECK: %[[MATMUL:.*]] = linalg.matmul -// // CHECK: scf.for -// // CHECK: scf.for -// // CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}}: tensor<512x512xf32>, vector<16xf32> -// // CHECK: %[[READ1:.*]] = vector.transfer_read {{.*}}, {{.*}}: tensor<512x512xf32>, vector<16xf32> -// // CHECK: %[[ADD0:.*]] = arith.addf %[[READ1]], %[[READ0]] : vector<16xf32> -// // CHECK: %[[ADD1:.*]] = arith.maximumf %[[ADD0]], {{.*}} : vector<16xf32> -// // CHECK: %[[WRITE:.*]] = vector.transfer_write {{.*}}, {{.*}} : vector<16xf32>, tensor<512x512xf32> -// func.func @fc_relu_test2(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>, -// %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>) -// -> tensor<512x512xf32> { +// CHECK-LABEL: func @fc_relu_test2 +// CHECK: %[[MATMUL:.*]] = linalg.matmul +// CHECK: scf.for +// CHECK: scf.for +// CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}}: tensor<512x512xf32>, vector<16xf32> +// CHECK: %[[READ1:.*]] = vector.transfer_read {{.*}}, {{.*}}: tensor<512x512xf32>, vector<16xf32> +// CHECK: %[[ADD0:.*]] = arith.addf %[[READ1]], %[[READ0]] : vector<16xf32> +// CHECK: %[[ADD1:.*]] = arith.maximumf %[[ADD0]], {{.*}} : vector<16xf32> +// CHECK: %[[WRITE:.*]] = vector.transfer_write {{.*}}, {{.*}} : vector<16xf32>, tensor<512x512xf32> +func.func @fc_relu_test2(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>, + %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>) + -> tensor<512x512xf32> { -// // Matrix-matrix multiplication. -// %matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>) -// outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32> + // Matrix-matrix multiplication. 
+ %matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>) + outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32> -// %biased = linalg.elemwise_binary { fun = #linalg.binary_fn } -// ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>) -// outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> + %biased = linalg.elemwise_binary { fun = #linalg.binary_fn } + ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>) + outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> -// %c0f = arith.constant 0.0 : f32 -// %relued = linalg.elemwise_binary { fun = #linalg.binary_fn } -// ins(%biased, %c0f : tensor<512x512xf32>, f32) -// outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> -// func.return %relued : tensor<512x512xf32> -// } + %c0f = arith.constant 0.0 : f32 + %relued = linalg.elemwise_binary { fun = #linalg.binary_fn } + ins(%biased, %c0f : tensor<512x512xf32>, f32) + outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> + func.return %relued : tensor<512x512xf32> +} -// // CHECK-LABEL: func @matmul_add_test3 -// // CHECK: %[[MATMUL0:.*]] = linalg.matmul -// // CHECK: scf.for -// // CHECK: scf.for -// // CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<32x32xf32>, vector<16xf32> -// // CHECK: %[[READ1:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<32x32xf32>, vector<16xf32> -// // CHECK: %[[ADD0:.*]] = arith.addf %[[READ1]], %[[READ0]] : vector<16xf32> -// // CHECK: %[[WRITE1:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<16xf32>, tensor<32x32xf32> -// // CHECK: scf.yield -// // CHECK: scf.yield -// func.func @matmul_add_test3(%arg0: tensor<8192x12288xf16>, %arg1: tensor<12288x16384xf16>, %arg2: tensor<8192x16384xf32>, %arg3: tensor<8192x16384xf32>) -> tensor<8192x16384xf32> { +// CHECK-LABEL: func @matmul_add_test3 +// CHECK: %[[MATMUL0:.*]] = linalg.matmul +// CHECK: scf.for +// CHECK: scf.for +// CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<32x32xf32>, vector<16xf32> +// CHECK: %[[READ1:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<32x32xf32>, vector<16xf32> +// CHECK: %[[ADD0:.*]] = arith.addf %[[READ1]], %[[READ0]] : vector<16xf32> +// CHECK: %[[WRITE1:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<16xf32>, tensor<32x32xf32> +// CHECK: scf.yield +// CHECK: scf.yield +func.func @matmul_add_test3(%arg0: tensor<8192x12288xf16>, %arg1: tensor<12288x16384xf16>, %arg2: tensor<8192x16384xf32>, %arg3: tensor<8192x16384xf32>) -> tensor<8192x16384xf32> { -// %0 = linalg.matmul {__fused_op__ = [0], name = "dot", tile_sizes = [[128, 128], [32, 32]]} ins(%arg0, %arg1 : tensor<8192x12288xf16>, tensor<12288x16384xf16>) outs(%arg2 : tensor<8192x16384xf32>) -> tensor<8192x16384xf32> -// %1 = tensor.empty() : tensor<8192x16384xf32> -// %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%0, %arg3 : tensor<8192x16384xf32>, tensor<8192x16384xf32>) outs(%1 : tensor<8192x16384xf32>) attrs = {__root_op__ = 0 : i64} { -// ^bb0(%in: f32, %in_0: f32, %out: f32): -// %4 = arith.addf %in, %in_0 : f32 -// linalg.yield %4 : f32 -// } -> tensor<8192x16384xf32> -// %c0 = arith.constant 0 : index -// %c8192 = arith.constant 8192 : index -// %c128 = arith.constant 128 : index -// %3 = scf.for %arg4 = %c0 to %c8192 step %c128 
iter_args(%arg5 = %1) -> (tensor<8192x16384xf32>) { -// %c0_0 = arith.constant 0 : index -// %c16384 = arith.constant 16384 : index -// %c128_1 = arith.constant 128 : index -// %4 = scf.for %arg6 = %c0_0 to %c16384 step %c128_1 iter_args(%arg7 = %arg5) -> (tensor<8192x16384xf32>) { -// %extracted_slice = tensor.extract_slice %arg0[%arg4, 0] [128, 12288] [1, 1] : tensor<8192x12288xf16> to tensor<128x12288xf16> -// %extracted_slice_2 = tensor.extract_slice %arg1[0, %arg6] [12288, 128] [1, 1] : tensor<12288x16384xf16> to tensor<12288x128xf16> -// %extracted_slice_3 = tensor.extract_slice %arg2[%arg4, %arg6] [128, 128] [1, 1] : tensor<8192x16384xf32> to tensor<128x128xf32> -// %5 = linalg.matmul {__fused_op__ = [0], name = "dot", tile_sizes = [[128, 128], [32, 32]]} ins(%extracted_slice, %extracted_slice_2 : tensor<128x12288xf16>, tensor<12288x128xf16>) outs(%extracted_slice_3 : tensor<128x128xf32>) -> tensor<128x128xf32> -// %extracted_slice_4 = tensor.extract_slice %0[%arg4, %arg6] [128, 128] [1, 1] : tensor<8192x16384xf32> to tensor<128x128xf32> -// %extracted_slice_5 = tensor.extract_slice %arg3[%arg4, %arg6] [128, 128] [1, 1] : tensor<8192x16384xf32> to tensor<128x128xf32> -// %extracted_slice_6 = tensor.extract_slice %arg7[%arg4, %arg6] [128, 128] [1, 1] : tensor<8192x16384xf32> to tensor<128x128xf32> -// %6 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%5, %extracted_slice_5 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%extracted_slice_6 : tensor<128x128xf32>) attrs = {__root_op__ = 0 : i64} { -// ^bb0(%in: f32, %in_9: f32, %out: f32): -// %8 = arith.addf %in, %in_9 : f32 -// linalg.yield %8 : f32 -// } -> tensor<128x128xf32> -// %c0_7 = arith.constant 0 : index -// %c128_8 = arith.constant 128 : index -// %c32 = arith.constant 32 : index -// %7 = scf.for %arg8 = %c0_7 to %c128_8 step %c32 iter_args(%arg9 = %extracted_slice_6) -> (tensor<128x128xf32>) { -// %c0_9 = arith.constant 0 : index -// %c128_10 = arith.constant 128 : index -// %c32_11 = arith.constant 32 : index -// %8 = scf.for %arg10 = %c0_9 to %c128_10 step %c32_11 iter_args(%arg11 = %arg9) -> (tensor<128x128xf32>) { -// %extracted_slice_12 = tensor.extract_slice %extracted_slice[%arg8, 0] [32, 12288] [1, 1] : tensor<128x12288xf16> to tensor<32x12288xf16> -// %extracted_slice_13 = tensor.extract_slice %extracted_slice_2[0, %arg10] [12288, 32] [1, 1] : tensor<12288x128xf16> to tensor<12288x32xf16> -// %extracted_slice_14 = tensor.extract_slice %extracted_slice_3[%arg8, %arg10] [32, 32] [1, 1] : tensor<128x128xf32> to tensor<32x32xf32> -// %9 = linalg.matmul {__fused_op__ = [0], name = "dot", tile_sizes = [[128, 128], [32, 32]]} ins(%extracted_slice_12, %extracted_slice_13 : tensor<32x12288xf16>, tensor<12288x32xf16>) outs(%extracted_slice_14 : tensor<32x32xf32>) -> tensor<32x32xf32> -// %extracted_slice_15 = tensor.extract_slice %5[%arg8, %arg10] [32, 32] [1, 1] : tensor<128x128xf32> to tensor<32x32xf32> -// %extracted_slice_16 = tensor.extract_slice %extracted_slice_5[%arg8, %arg10] [32, 32] [1, 1] : tensor<128x128xf32> to tensor<32x32xf32> -// %extracted_slice_17 = tensor.extract_slice %arg11[%arg8, %arg10] [32, 32] [1, 1] : tensor<128x128xf32> to tensor<32x32xf32> -// %10 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%9, %extracted_slice_16 : 
tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice_17 : tensor<32x32xf32>) attrs = {__root_op__ = 0 : i64} { -// ^bb0(%in: f32, %in_19: f32, %out: f32): -// %11 = arith.addf %in, %in_19 : f32 -// linalg.yield %11 : f32 -// } -> tensor<32x32xf32> -// %inserted_slice_18 = tensor.insert_slice %10 into %arg11[%arg8, %arg10] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<128x128xf32> -// scf.yield %inserted_slice_18 : tensor<128x128xf32> -// } {__parallel_loop__ = 1 : i64} -// scf.yield %8 : tensor<128x128xf32> -// } {__parallel_loop__ = 1 : i64} -// %inserted_slice = tensor.insert_slice %7 into %arg7[%arg4, %arg6] [128, 128] [1, 1] : tensor<128x128xf32> into tensor<8192x16384xf32> -// scf.yield %inserted_slice : tensor<8192x16384xf32> -// } {__parallel_loop__ = 0 : i64} -// scf.yield %4 : tensor<8192x16384xf32> -// } {__parallel_loop__ = 0 : i64} -// return %3 : tensor<8192x16384xf32> -// } - -// // CHECK-LABEL: func @fuse_mlp_test4 -// #map = affine_map<(d0) -> (d0 * 64)> -// #map1 = affine_map<(d0) -> (d0 * 128)> -// #map2 = affine_map<(d0) -> (d0 * 4)> -// #map3 = affine_map<(d0) -> (d0 floordiv 16)> -// #map4 = affine_map<(d0) -> (d0 floordiv 32)> -// func.func @fuse_mlp_test4(%arg0: tensor<128x512xbf16>, %arg1: tensor<32x8x16x32xbf16>, %arg2: tensor<256xbf16>) -> tensor<128x256xbf16> { -// %c32 = arith.constant 32 : index -// %c512 = arith.constant 512 : index -// %c128 = arith.constant 128 : index -// %c64 = arith.constant 64 : index -// %c0 = arith.constant 0 : index -// %cst = arith.constant 0.000000e+00 : bf16 -// %0 = tensor.empty() : tensor<128x256xbf16> -// %1 = tensor.empty() : tensor<512x256xbf16> -// %2:3 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %0, %arg6 = %0, %arg7 = %0) -> (tensor<128x256xbf16>, tensor<128x256xbf16>, tensor<128x256xbf16>) { -// %3 = affine.apply #map(%arg3) -// %4 = affine.apply #map1(%arg4) -// %extracted_slice = tensor.extract_slice %arg0[%3, 0] [64, 512] [1, 1] : tensor<128x512xbf16> to tensor<64x512xbf16> -// %5 = affine.apply #map2(%arg4) -// %extracted_slice_0 = tensor.extract_slice %arg1[0, %5, 0, 0] [32, 4, 16, 32] [1, 1, 1, 1] : tensor<32x8x16x32xbf16> to tensor<32x4x16x32xbf16> -// %extracted_slice_1 = tensor.extract_slice %1[0, %4] [512, 128] [1, 1] : tensor<512x256xbf16> to tensor<512x128xbf16> -// %extracted_slice_2 = tensor.extract_slice %arg5[%3, %4] [64, 128] [1, 1] : tensor<128x256xbf16> to tensor<64x128xbf16> -// %extracted_slice_3 = tensor.extract_slice %arg2[%4] [128] [1] : tensor<256xbf16> to tensor<128xbf16> -// %extracted_slice_4 = tensor.extract_slice %0[%3, %4] [64, 128] [1, 1] : tensor<128x256xbf16> to tensor<64x128xbf16> -// %extracted_slice_5 = tensor.extract_slice %arg6[%3, %4] [64, 128] [1, 1] : tensor<128x256xbf16> to tensor<64x128xbf16> -// %extracted_slice_6 = tensor.extract_slice %arg7[%3, %4] [64, 128] [1, 1] : tensor<128x256xbf16> to tensor<64x128xbf16> -// %6:3 = scf.for %arg8 = %c0 to %c64 step %c64 iter_args(%arg9 = %extracted_slice_2, %arg10 = %extracted_slice_5, %arg11 = %extracted_slice_6) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { -// %7:3 = scf.for %arg12 = %c0 to %c128 step %c128 iter_args(%arg13 = %arg9, %arg14 = %arg10, %arg15 = %arg11) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { -// %8:3 = scf.for %arg16 = %c0 to %c512 step %c512 iter_args(%arg17 = %arg13, %arg18 = %arg14, %arg19 = %arg15) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { -// %extracted_slice_7 = tensor.extract_slice %extracted_slice[%arg8, %arg16] 
[64, 512] [1, 1] : tensor<64x512xbf16> to tensor<64x512xbf16> -// %9 = affine.apply #map3(%arg16) -// %10 = affine.apply #map4(%arg12) -// %extracted_slice_8 = tensor.extract_slice %extracted_slice_0[%9, %10, 0, 0] [32, 4, 16, 32] [1, 1, 1, 1] : tensor<32x4x16x32xbf16> to tensor<32x4x16x32xbf16> -// %extracted_slice_9 = tensor.extract_slice %extracted_slice_1[%arg16, %arg12] [512, 128] [1, 1] : tensor<512x128xbf16> to tensor<512x128xbf16> -// %extracted_slice_10 = tensor.extract_slice %arg17[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> to tensor<64x128xbf16> -// %extracted_slice_11 = tensor.extract_slice %extracted_slice_3[%arg12] [128] [1] : tensor<128xbf16> to tensor<128xbf16> -// %extracted_slice_12 = tensor.extract_slice %extracted_slice_4[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> to tensor<64x128xbf16> -// %extracted_slice_13 = tensor.extract_slice %arg18[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> to tensor<64x128xbf16> -// %extracted_slice_14 = tensor.extract_slice %arg19[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> to tensor<64x128xbf16> -// %11:3 = scf.for %arg20 = %c0 to %c64 step %c32 iter_args(%arg21 = %extracted_slice_10, %arg22 = %extracted_slice_13, %arg23 = %extracted_slice_14) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { -// %12:3 = scf.for %arg24 = %c0 to %c128 step %c32 iter_args(%arg25 = %arg21, %arg26 = %arg22, %arg27 = %arg23) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { -// %13:3 = scf.for %arg28 = %c0 to %c512 step %c512 iter_args(%arg29 = %arg25, %arg30 = %arg26, %arg31 = %arg27) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { -// %extracted_slice_17 = tensor.extract_slice %extracted_slice_7[%arg20, %arg28] [32, 512] [1, 1] : tensor<64x512xbf16> to tensor<32x512xbf16> -// %14 = affine.apply #map3(%arg28) -// %15 = affine.apply #map4(%arg24) -// %extracted_slice_18 = tensor.extract_slice %extracted_slice_8[%14, %15, 0, 0] [32, 1, 16, 32] [1, 1, 1, 1] : tensor<32x4x16x32xbf16> to tensor<32x1x16x32xbf16> -// %extracted_slice_19 = tensor.extract_slice %extracted_slice_9[%arg28, %arg24] [512, 32] [1, 1] : tensor<512x128xbf16> to tensor<512x32xbf16> -// %unpack = tensor.unpack %extracted_slice_18 inner_dims_pos = [0, 1] inner_tiles = [16, 32] into %extracted_slice_19 : tensor<32x1x16x32xbf16> -> tensor<512x32xbf16> -// // CHECK: %[[C16:.*]] = arith.constant 16 : index -// // CHECK: %[[C1:.*]] = arith.constant 1 : index -// // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : bf16 -// // CHECK: %[[C0:.*]] = arith.constant 0 : index -// // CHECK: %[[C64:.*]] = arith.constant 64 : index -// // CHECK: %[[C128:.*]] = arith.constant 128 : index -// // CHECK: %[[C512:.*]] = arith.constant 512 : index -// // CHECK: %[[C32:.*]] = arith.constant 32 : index -// // CHECK: %[[EMPTY0:.*]] = tensor.empty() : tensor<128x256xbf16> -// // CHECK: scf.forall -// // CHECK-COUNT-6: scf.for -// // CHECK-COUNT-4: scf.for -// // CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<32x1x16x32xbf16>, vector<1xbf16> -// // CHECK: %[[WRITE0:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<1xbf16>, tensor<32x16x1x32xbf16> -// // CHECK: %[[FILL0:.*]] = linalg.fill -// // CHECK-COUNT-3: scf.for -// // CHECK: %[[APPLY0:.*]] = affine.apply -// // CHECK: %[[READ1:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<32x512xbf16>, vector<32xbf16> -// // CHECK: %[[WRITE1:.*]] = vector.transfer_write {{.*}}, {{.*}} 
{in_bounds = [true]} : vector<32xbf16>, tensor<1x32x512xbf16> -// // CHECK-COUNT-4: scf.for -// // CHECK: %[[READ2:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<32x16x1x32xbf16>, vector<32xbf16> -// // CHECK: %[[APPLY1:.*]] = affine.apply -// // CHECK: %[[WRITE2:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<32xbf16>, tensor<1x512x32xbf16> -// // CHECK: %[[MATMUL0:.*]] = linalg.batch_reduce_matmul -// // CHECK-COUNT-2: scf.for -// // CHECK: %[[READ3:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<32x32xbf16>, vector<32xbf16> -// // CHECK: %[[READ4:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<32xbf16>, vector<32xbf16> -// // CHECK: %[[ADD0:.*]] = arith.addf %[[READ3]], %[[READ4]] : vector<32xbf16> -// // CHECK: %[[EXP0:.*]] = math.exp %[[ADD0]] : vector<32xbf16> -// // CHECK: %[[WRITE3:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<32xbf16>, tensor<32x32xbf16> -// // CHECK: %[[WRITE4:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<32xbf16>, tensor<32x32xbf16> -// %extracted_slice_20 = tensor.extract_slice %arg29[%arg20, %arg24] [32, 32] [1, 1] : tensor<64x128xbf16> to tensor<32x32xbf16> -// %16 = linalg.fill ins(%cst : bf16) outs(%extracted_slice_20 : tensor<32x32xbf16>) -> tensor<32x32xbf16> -// %expanded = tensor.expand_shape %extracted_slice_17 [[0, 1], [2]] output_shape [1, 32, 512] : tensor<32x512xbf16> into tensor<1x32x512xbf16> -// %expanded_21 = tensor.expand_shape %unpack [[0, 1], [2]] output_shape [1, 32, 512] : tensor<512x32xbf16> into tensor<1x512x32xbf16> -// %17 = linalg.batch_reduce_matmul ins(%expanded, %expanded_21 : tensor<1x32x512xbf16>, tensor<1x512x32xbf16>) outs(%16 : tensor<32x32xbf16>) -> tensor<32x32xbf16> -// %extracted_slice_22 = tensor.extract_slice %extracted_slice_11[%arg24] [32] [1] : tensor<128xbf16> to tensor<32xbf16> -// %extracted_slice_23 = tensor.extract_slice %extracted_slice_12[%arg20, %arg24] [32, 32] [1, 1] : tensor<64x128xbf16> to tensor<32x32xbf16> -// %broadcasted = linalg.broadcast ins(%extracted_slice_22 : tensor<32xbf16>) outs(%extracted_slice_23 : tensor<32x32xbf16>) dimensions = [0] -// %extracted_slice_24 = tensor.extract_slice %arg30[%arg20, %arg24] [32, 32] [1, 1] : tensor<64x128xbf16> to tensor<32x32xbf16> -// %18 = linalg.add ins(%17, %broadcasted : tensor<32x32xbf16>, tensor<32x32xbf16>) outs(%extracted_slice_24 : tensor<32x32xbf16>) -> tensor<32x32xbf16> -// %inserted_slice_25 = tensor.insert_slice %17 into %arg29[%arg20, %arg24] [32, 32] [1, 1] : tensor<32x32xbf16> into tensor<64x128xbf16> -// %extracted_slice_26 = tensor.extract_slice %arg31[%arg20, %arg24] [32, 32] [1, 1] : tensor<64x128xbf16> to tensor<32x32xbf16> -// %19 = linalg.exp ins(%18 : tensor<32x32xbf16>) outs(%extracted_slice_26 : tensor<32x32xbf16>) -> tensor<32x32xbf16> -// %inserted_slice_27 = tensor.insert_slice %18 into %arg30[%arg20, %arg24] [32, 32] [1, 1] : tensor<32x32xbf16> into tensor<64x128xbf16> -// %inserted_slice_28 = tensor.insert_slice %19 into %arg31[%arg20, %arg24] [32, 32] [1, 1] : tensor<32x32xbf16> into tensor<64x128xbf16> -// scf.yield %inserted_slice_25, %inserted_slice_27, %inserted_slice_28 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> -// } -// scf.yield %13#0, %13#1, %13#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> -// } -// scf.yield %12#0, %12#1, %12#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> -// } -// %inserted_slice = 
tensor.insert_slice %11#0 into %arg17[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<64x128xbf16> -// %inserted_slice_15 = tensor.insert_slice %11#1 into %arg18[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<64x128xbf16> -// %inserted_slice_16 = tensor.insert_slice %11#2 into %arg19[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<64x128xbf16> -// scf.yield %inserted_slice, %inserted_slice_15, %inserted_slice_16 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> -// } -// scf.yield %8#0, %8#1, %8#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> -// } -// scf.yield %7#0, %7#1, %7#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> -// } -// scf.forall.in_parallel { -// tensor.parallel_insert_slice %6#2 into %arg7[%3, %4] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<128x256xbf16> -// tensor.parallel_insert_slice %6#1 into %arg6[%3, %4] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<128x256xbf16> -// tensor.parallel_insert_slice %6#0 into %arg5[%3, %4] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<128x256xbf16> -// } -// } -// return %2#2 : tensor<128x256xbf16> -// } - -// // CHECK-LABEL: func @elem_pack_transpose_inner_dims_test5 -// // CHECK: %[[C32:.*]] = arith.constant 32 : index -// // CHECK: %[[C4:.*]] = arith.constant 4 : index -// // CHECK: %[[C256:.*]] = arith.constant 256 : index -// // CHECK: %[[C16:.*]] = arith.constant 16 : index -// // CHECK: %[[C128:.*]] = arith.constant 128 : index -// // CHECK: %[[C1:.*]] = arith.constant 1 : index -// // CHECK: %[[C0:.*]] = arith.constant 0 : index -// // CHECK: %[[C32I32:.*]] = arith.constant 0 : i32 -// // CHECK: %[[EMPTY0:.*]] = tensor.empty() : tensor<4x32x16x16xi32> -// // CHECK: %[[EMPTY1:.*]] = tensor.empty() : tensor<128x256xi32> -// // CHECK: %[[EMPTY2:.*]] = tensor.empty() : tensor<4x16x16x32xi32> -// // CHECK: scf.for -// // CHECK: scf.for -// // CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<128x256xi32>, vector<16xi32> -// // CHECK: %[[ADD0:.*]] = arith.addi %[[READ0]], %[[READ0]] : vector<16xi32> -// // CHECK: %[[WRITE0:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<16xi32>, tensor<128x256xi32> -// // CHECK: scf.for %[[arg2:.*]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[arg3:.*]] = %[[EMPTY0]]) -> (tensor<4x32x16x16xi32>) -// // CHECK: scf.for %[[arg4:.*]] = %[[C0]] to %[[C32]] step %[[C1]] iter_args(%[[arg5:.*]] = %[[arg3]]) -> (tensor<4x32x16x16xi32>) -// // CHECK: scf.for %[[arg6:.*]] = %[[C0]] to %[[C16]] step %[[C1]] iter_args(%[[arg7:.*]] = %[[arg5]]) -> (tensor<4x32x16x16xi32>) -// // CHECK: scf.for %[[arg8:.*]] = %[[C0]] to %[[C16]] step %[[C16]] iter_args(%[[arg9:.*]] = %[[arg7]]) -> (tensor<4x32x16x16xi32>) -// // CHECK: %[[APPLY0:.*]] = affine.apply #[[map7]]()[%[[arg2]], %[[arg4]]] -// // CHECK: %[[APPLY1:.*]] = affine.apply #[[map8]]()[%[[arg6]], %[[arg8]]] -// // CHECK: %[[READ1:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<128x256xi32>, vector<16xi32> -// // CHECK: %[[WRITE1:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<16xi32>, tensor<4x32x16x16xi32> -// // CHECK-COUNT-4: scf.for -// // CHECK: %[[READ2:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<4x32x16x16xi32>, vector<1xi32> -// // CHECK: %[[WRITE2:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<1xi32>, tensor<4x16x16x32xi32> -// #map5 = affine_map<(d0, d1) -> (d0, 
d1)> -// func.func @elem_pack_transpose_inner_dims_test5(%arg0: tensor<128x256xi32>, %dest: tensor<4x16x16x32xi32>) -> tensor<4x16x16x32xi32>{ -// %init = tensor.empty() : tensor<128x256xi32> -// %elem = linalg.generic {indexing_maps = [#map5, #map5], iterator_types = ["parallel", "parallel"]} -// ins(%arg0 : tensor<128x256xi32>) -// outs(%init : tensor<128x256xi32>) { -// ^bb0(%arg3: i32, %arg4: i32): -// %4 = arith.addi %arg3, %arg3 : i32 -// linalg.yield %4 : i32 -// } -> tensor<128x256xi32> -// %pack = tensor.pack %elem -// inner_dims_pos = [1, 0] -// inner_tiles = [16, 32] -// into %dest : tensor<128x256xi32> -> tensor<4x16x16x32xi32> -// return %pack : tensor<4x16x16x32xi32> -// } - -// #map6 = affine_map<(d0, d1) -> (d0, d1)> -// func.func @elem_pack_transpose_outer_dims_test6(%arg0: tensor<128x256xi32>, %dest: tensor<16x4x32x16xi32>) -> tensor<16x4x32x16xi32>{ -// %init = tensor.empty() : tensor<128x256xi32> -// %elem = linalg.generic {indexing_maps = [#map6, #map6], iterator_types = ["parallel", "parallel"]} -// ins(%arg0 : tensor<128x256xi32>) -// outs(%init : tensor<128x256xi32>) { -// ^bb0(%arg3: i32, %arg4: i32): -// %4 = arith.addi %arg3, %arg3 : i32 -// linalg.yield %4 : i32 -// } -> tensor<128x256xi32> -// %pack = tensor.pack %elem -// outer_dims_perm = [1, 0] -// inner_dims_pos = [0, 1] -// inner_tiles = [32, 16] -// into %dest : tensor<128x256xi32> -> tensor<16x4x32x16xi32> -// return %pack : tensor<16x4x32x16xi32> -// } - + %0 = linalg.matmul {__fused_op__ = [0], name = "dot", tile_sizes = [[128, 128], [32, 32]]} ins(%arg0, %arg1 : tensor<8192x12288xf16>, tensor<12288x16384xf16>) outs(%arg2 : tensor<8192x16384xf32>) -> tensor<8192x16384xf32> + %1 = tensor.empty() : tensor<8192x16384xf32> + %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%0, %arg3 : tensor<8192x16384xf32>, tensor<8192x16384xf32>) outs(%1 : tensor<8192x16384xf32>) attrs = {__root_op__ = 0 : i64} { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %4 = arith.addf %in, %in_0 : f32 + linalg.yield %4 : f32 + } -> tensor<8192x16384xf32> + %c0 = arith.constant 0 : index + %c8192 = arith.constant 8192 : index + %c128 = arith.constant 128 : index + %3 = scf.for %arg4 = %c0 to %c8192 step %c128 iter_args(%arg5 = %1) -> (tensor<8192x16384xf32>) { + %c0_0 = arith.constant 0 : index + %c16384 = arith.constant 16384 : index + %c128_1 = arith.constant 128 : index + %4 = scf.for %arg6 = %c0_0 to %c16384 step %c128_1 iter_args(%arg7 = %arg5) -> (tensor<8192x16384xf32>) { + %extracted_slice = tensor.extract_slice %arg0[%arg4, 0] [128, 12288] [1, 1] : tensor<8192x12288xf16> to tensor<128x12288xf16> + %extracted_slice_2 = tensor.extract_slice %arg1[0, %arg6] [12288, 128] [1, 1] : tensor<12288x16384xf16> to tensor<12288x128xf16> + %extracted_slice_3 = tensor.extract_slice %arg2[%arg4, %arg6] [128, 128] [1, 1] : tensor<8192x16384xf32> to tensor<128x128xf32> + %5 = linalg.matmul {__fused_op__ = [0], name = "dot", tile_sizes = [[128, 128], [32, 32]]} ins(%extracted_slice, %extracted_slice_2 : tensor<128x12288xf16>, tensor<12288x128xf16>) outs(%extracted_slice_3 : tensor<128x128xf32>) -> tensor<128x128xf32> + %extracted_slice_4 = tensor.extract_slice %0[%arg4, %arg6] [128, 128] [1, 1] : tensor<8192x16384xf32> to tensor<128x128xf32> + %extracted_slice_5 = tensor.extract_slice %arg3[%arg4, %arg6] [128, 128] [1, 1] : tensor<8192x16384xf32> to tensor<128x128xf32> + %extracted_slice_6 = 
tensor.extract_slice %arg7[%arg4, %arg6] [128, 128] [1, 1] : tensor<8192x16384xf32> to tensor<128x128xf32> + %6 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%5, %extracted_slice_5 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%extracted_slice_6 : tensor<128x128xf32>) attrs = {__root_op__ = 0 : i64} { + ^bb0(%in: f32, %in_9: f32, %out: f32): + %8 = arith.addf %in, %in_9 : f32 + linalg.yield %8 : f32 + } -> tensor<128x128xf32> + %c0_7 = arith.constant 0 : index + %c128_8 = arith.constant 128 : index + %c32 = arith.constant 32 : index + %7 = scf.for %arg8 = %c0_7 to %c128_8 step %c32 iter_args(%arg9 = %extracted_slice_6) -> (tensor<128x128xf32>) { + %c0_9 = arith.constant 0 : index + %c128_10 = arith.constant 128 : index + %c32_11 = arith.constant 32 : index + %8 = scf.for %arg10 = %c0_9 to %c128_10 step %c32_11 iter_args(%arg11 = %arg9) -> (tensor<128x128xf32>) { + %extracted_slice_12 = tensor.extract_slice %extracted_slice[%arg8, 0] [32, 12288] [1, 1] : tensor<128x12288xf16> to tensor<32x12288xf16> + %extracted_slice_13 = tensor.extract_slice %extracted_slice_2[0, %arg10] [12288, 32] [1, 1] : tensor<12288x128xf16> to tensor<12288x32xf16> + %extracted_slice_14 = tensor.extract_slice %extracted_slice_3[%arg8, %arg10] [32, 32] [1, 1] : tensor<128x128xf32> to tensor<32x32xf32> + %9 = linalg.matmul {__fused_op__ = [0], name = "dot", tile_sizes = [[128, 128], [32, 32]]} ins(%extracted_slice_12, %extracted_slice_13 : tensor<32x12288xf16>, tensor<12288x32xf16>) outs(%extracted_slice_14 : tensor<32x32xf32>) -> tensor<32x32xf32> + %extracted_slice_15 = tensor.extract_slice %5[%arg8, %arg10] [32, 32] [1, 1] : tensor<128x128xf32> to tensor<32x32xf32> + %extracted_slice_16 = tensor.extract_slice %extracted_slice_5[%arg8, %arg10] [32, 32] [1, 1] : tensor<128x128xf32> to tensor<32x32xf32> + %extracted_slice_17 = tensor.extract_slice %arg11[%arg8, %arg10] [32, 32] [1, 1] : tensor<128x128xf32> to tensor<32x32xf32> + %10 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%9, %extracted_slice_16 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice_17 : tensor<32x32xf32>) attrs = {__root_op__ = 0 : i64} { + ^bb0(%in: f32, %in_19: f32, %out: f32): + %11 = arith.addf %in, %in_19 : f32 + linalg.yield %11 : f32 + } -> tensor<32x32xf32> + %inserted_slice_18 = tensor.insert_slice %10 into %arg11[%arg8, %arg10] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<128x128xf32> + scf.yield %inserted_slice_18 : tensor<128x128xf32> + } {__parallel_loop__ = 1 : i64} + scf.yield %8 : tensor<128x128xf32> + } {__parallel_loop__ = 1 : i64} + %inserted_slice = tensor.insert_slice %7 into %arg7[%arg4, %arg6] [128, 128] [1, 1] : tensor<128x128xf32> into tensor<8192x16384xf32> + scf.yield %inserted_slice : tensor<8192x16384xf32> + } {__parallel_loop__ = 0 : i64} + scf.yield %4 : tensor<8192x16384xf32> + } {__parallel_loop__ = 0 : i64} + return %3 : tensor<8192x16384xf32> +} -// #map7 = affine_map<(d0, d1) -> (d0, d1)> -// func.func @elem_pack_transpose_inner_and_outer_dims_test7(%arg0: tensor<128x256xi32>, %dest: tensor<16x4x16x32xi32>) -> tensor<16x4x16x32xi32>{ -// %init = tensor.empty() : tensor<128x256xi32> -// %elem = linalg.generic {indexing_maps = [#map7, #map7], iterator_types = ["parallel", "parallel"]} -// ins(%arg0 : tensor<128x256xi32>) -// 
outs(%init : tensor<128x256xi32>) { -// ^bb0(%arg3: i32, %arg4: i32): -// %4 = arith.addi %arg3, %arg3 : i32 -// linalg.yield %4 : i32 -// } -> tensor<128x256xi32> -// %pack = tensor.pack %elem -// outer_dims_perm = [1, 0] -// inner_dims_pos = [1, 0] -// inner_tiles = [16, 32] -// into %dest : tensor<128x256xi32> -> tensor<16x4x16x32xi32> -// return %pack : tensor<16x4x16x32xi32> -// } +// CHECK-LABEL: func @fuse_mlp_test4 +#map = affine_map<(d0) -> (d0 * 64)> +#map1 = affine_map<(d0) -> (d0 * 128)> +#map2 = affine_map<(d0) -> (d0 * 4)> +#map3 = affine_map<(d0) -> (d0 floordiv 16)> +#map4 = affine_map<(d0) -> (d0 floordiv 32)> +func.func @fuse_mlp_test4(%arg0: tensor<128x512xbf16>, %arg1: tensor<32x8x16x32xbf16>, %arg2: tensor<256xbf16>) -> tensor<128x256xbf16> { + %c32 = arith.constant 32 : index + %c512 = arith.constant 512 : index + %c128 = arith.constant 128 : index + %c64 = arith.constant 64 : index + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : bf16 + %0 = tensor.empty() : tensor<128x256xbf16> + %1 = tensor.empty() : tensor<512x256xbf16> + %2:3 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %0, %arg6 = %0, %arg7 = %0) -> (tensor<128x256xbf16>, tensor<128x256xbf16>, tensor<128x256xbf16>) { + %3 = affine.apply #map(%arg3) + %4 = affine.apply #map1(%arg4) + %extracted_slice = tensor.extract_slice %arg0[%3, 0] [64, 512] [1, 1] : tensor<128x512xbf16> to tensor<64x512xbf16> + %5 = affine.apply #map2(%arg4) + %extracted_slice_0 = tensor.extract_slice %arg1[0, %5, 0, 0] [32, 4, 16, 32] [1, 1, 1, 1] : tensor<32x8x16x32xbf16> to tensor<32x4x16x32xbf16> + %extracted_slice_1 = tensor.extract_slice %1[0, %4] [512, 128] [1, 1] : tensor<512x256xbf16> to tensor<512x128xbf16> + %extracted_slice_2 = tensor.extract_slice %arg5[%3, %4] [64, 128] [1, 1] : tensor<128x256xbf16> to tensor<64x128xbf16> + %extracted_slice_3 = tensor.extract_slice %arg2[%4] [128] [1] : tensor<256xbf16> to tensor<128xbf16> + %extracted_slice_4 = tensor.extract_slice %0[%3, %4] [64, 128] [1, 1] : tensor<128x256xbf16> to tensor<64x128xbf16> + %extracted_slice_5 = tensor.extract_slice %arg6[%3, %4] [64, 128] [1, 1] : tensor<128x256xbf16> to tensor<64x128xbf16> + %extracted_slice_6 = tensor.extract_slice %arg7[%3, %4] [64, 128] [1, 1] : tensor<128x256xbf16> to tensor<64x128xbf16> + %6:3 = scf.for %arg8 = %c0 to %c64 step %c64 iter_args(%arg9 = %extracted_slice_2, %arg10 = %extracted_slice_5, %arg11 = %extracted_slice_6) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { + %7:3 = scf.for %arg12 = %c0 to %c128 step %c128 iter_args(%arg13 = %arg9, %arg14 = %arg10, %arg15 = %arg11) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { + %8:3 = scf.for %arg16 = %c0 to %c512 step %c512 iter_args(%arg17 = %arg13, %arg18 = %arg14, %arg19 = %arg15) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { + %extracted_slice_7 = tensor.extract_slice %extracted_slice[%arg8, %arg16] [64, 512] [1, 1] : tensor<64x512xbf16> to tensor<64x512xbf16> + %9 = affine.apply #map3(%arg16) + %10 = affine.apply #map4(%arg12) + %extracted_slice_8 = tensor.extract_slice %extracted_slice_0[%9, %10, 0, 0] [32, 4, 16, 32] [1, 1, 1, 1] : tensor<32x4x16x32xbf16> to tensor<32x4x16x32xbf16> + %extracted_slice_9 = tensor.extract_slice %extracted_slice_1[%arg16, %arg12] [512, 128] [1, 1] : tensor<512x128xbf16> to tensor<512x128xbf16> + %extracted_slice_10 = tensor.extract_slice %arg17[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> to tensor<64x128xbf16> + %extracted_slice_11 = 
tensor.extract_slice %extracted_slice_3[%arg12] [128] [1] : tensor<128xbf16> to tensor<128xbf16> + %extracted_slice_12 = tensor.extract_slice %extracted_slice_4[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> to tensor<64x128xbf16> + %extracted_slice_13 = tensor.extract_slice %arg18[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> to tensor<64x128xbf16> + %extracted_slice_14 = tensor.extract_slice %arg19[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> to tensor<64x128xbf16> + %11:3 = scf.for %arg20 = %c0 to %c64 step %c32 iter_args(%arg21 = %extracted_slice_10, %arg22 = %extracted_slice_13, %arg23 = %extracted_slice_14) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { + %12:3 = scf.for %arg24 = %c0 to %c128 step %c32 iter_args(%arg25 = %arg21, %arg26 = %arg22, %arg27 = %arg23) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { + %13:3 = scf.for %arg28 = %c0 to %c512 step %c512 iter_args(%arg29 = %arg25, %arg30 = %arg26, %arg31 = %arg27) -> (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) { + %extracted_slice_17 = tensor.extract_slice %extracted_slice_7[%arg20, %arg28] [32, 512] [1, 1] : tensor<64x512xbf16> to tensor<32x512xbf16> + %14 = affine.apply #map3(%arg28) + %15 = affine.apply #map4(%arg24) + %extracted_slice_18 = tensor.extract_slice %extracted_slice_8[%14, %15, 0, 0] [32, 1, 16, 32] [1, 1, 1, 1] : tensor<32x4x16x32xbf16> to tensor<32x1x16x32xbf16> + %extracted_slice_19 = tensor.extract_slice %extracted_slice_9[%arg28, %arg24] [512, 32] [1, 1] : tensor<512x128xbf16> to tensor<512x32xbf16> + %unpack = tensor.unpack %extracted_slice_18 inner_dims_pos = [0, 1] inner_tiles = [16, 32] into %extracted_slice_19 : tensor<32x1x16x32xbf16> -> tensor<512x32xbf16> +// CHECK: %[[C16:.*]] = arith.constant 16 : index +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : bf16 +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[C64:.*]] = arith.constant 64 : index +// CHECK: %[[C128:.*]] = arith.constant 128 : index +// CHECK: %[[C512:.*]] = arith.constant 512 : index +// CHECK: %[[C32:.*]] = arith.constant 32 : index +// CHECK: %[[EMPTY0:.*]] = tensor.empty() : tensor<128x256xbf16> +// CHECK: scf.forall +// CHECK-COUNT-6: scf.for +// CHECK-COUNT-4: scf.for +// CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<32x1x16x32xbf16>, vector<1xbf16> +// CHECK: %[[WRITE0:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<1xbf16>, tensor<32x16x1x32xbf16> +// CHECK: %[[FILL0:.*]] = linalg.fill +// CHECK-COUNT-3: scf.for +// CHECK: %[[APPLY0:.*]] = affine.apply +// CHECK: %[[READ1:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<32x512xbf16>, vector<32xbf16> +// CHECK: %[[WRITE1:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<32xbf16>, tensor<1x32x512xbf16> +// CHECK-COUNT-4: scf.for +// CHECK: %[[READ2:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<32x16x1x32xbf16>, vector<32xbf16> +// CHECK: %[[APPLY1:.*]] = affine.apply +// CHECK: %[[WRITE2:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<32xbf16>, tensor<1x512x32xbf16> +// CHECK: %[[MATMUL0:.*]] = linalg.batch_reduce_matmul +// CHECK-COUNT-2: scf.for +// CHECK: %[[READ3:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<32x32xbf16>, vector<32xbf16> +// CHECK: %[[READ4:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : 
tensor<32xbf16>, vector<32xbf16> +// CHECK: %[[ADD0:.*]] = arith.addf %[[READ3]], %[[READ4]] : vector<32xbf16> +// CHECK: %[[EXP0:.*]] = math.exp %[[ADD0]] : vector<32xbf16> +// CHECK: %[[WRITE3:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<32xbf16>, tensor<32x32xbf16> +// CHECK: %[[WRITE4:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<32xbf16>, tensor<32x32xbf16> + %extracted_slice_20 = tensor.extract_slice %arg29[%arg20, %arg24] [32, 32] [1, 1] : tensor<64x128xbf16> to tensor<32x32xbf16> + %16 = linalg.fill ins(%cst : bf16) outs(%extracted_slice_20 : tensor<32x32xbf16>) -> tensor<32x32xbf16> + %expanded = tensor.expand_shape %extracted_slice_17 [[0, 1], [2]] output_shape [1, 32, 512] : tensor<32x512xbf16> into tensor<1x32x512xbf16> + %expanded_21 = tensor.expand_shape %unpack [[0, 1], [2]] output_shape [1, 32, 512] : tensor<512x32xbf16> into tensor<1x512x32xbf16> + %17 = linalg.batch_reduce_matmul ins(%expanded, %expanded_21 : tensor<1x32x512xbf16>, tensor<1x512x32xbf16>) outs(%16 : tensor<32x32xbf16>) -> tensor<32x32xbf16> + %extracted_slice_22 = tensor.extract_slice %extracted_slice_11[%arg24] [32] [1] : tensor<128xbf16> to tensor<32xbf16> + %extracted_slice_23 = tensor.extract_slice %extracted_slice_12[%arg20, %arg24] [32, 32] [1, 1] : tensor<64x128xbf16> to tensor<32x32xbf16> + %broadcasted = linalg.broadcast ins(%extracted_slice_22 : tensor<32xbf16>) outs(%extracted_slice_23 : tensor<32x32xbf16>) dimensions = [0] + %extracted_slice_24 = tensor.extract_slice %arg30[%arg20, %arg24] [32, 32] [1, 1] : tensor<64x128xbf16> to tensor<32x32xbf16> + %18 = linalg.add ins(%17, %broadcasted : tensor<32x32xbf16>, tensor<32x32xbf16>) outs(%extracted_slice_24 : tensor<32x32xbf16>) -> tensor<32x32xbf16> + %inserted_slice_25 = tensor.insert_slice %17 into %arg29[%arg20, %arg24] [32, 32] [1, 1] : tensor<32x32xbf16> into tensor<64x128xbf16> + %extracted_slice_26 = tensor.extract_slice %arg31[%arg20, %arg24] [32, 32] [1, 1] : tensor<64x128xbf16> to tensor<32x32xbf16> + %19 = linalg.exp ins(%18 : tensor<32x32xbf16>) outs(%extracted_slice_26 : tensor<32x32xbf16>) -> tensor<32x32xbf16> + %inserted_slice_27 = tensor.insert_slice %18 into %arg30[%arg20, %arg24] [32, 32] [1, 1] : tensor<32x32xbf16> into tensor<64x128xbf16> + %inserted_slice_28 = tensor.insert_slice %19 into %arg31[%arg20, %arg24] [32, 32] [1, 1] : tensor<32x32xbf16> into tensor<64x128xbf16> + scf.yield %inserted_slice_25, %inserted_slice_27, %inserted_slice_28 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> + } + scf.yield %13#0, %13#1, %13#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> + } + scf.yield %12#0, %12#1, %12#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> + } + %inserted_slice = tensor.insert_slice %11#0 into %arg17[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<64x128xbf16> + %inserted_slice_15 = tensor.insert_slice %11#1 into %arg18[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<64x128xbf16> + %inserted_slice_16 = tensor.insert_slice %11#2 into %arg19[%arg8, %arg12] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<64x128xbf16> + scf.yield %inserted_slice, %inserted_slice_15, %inserted_slice_16 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> + } + scf.yield %8#0, %8#1, %8#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> + } + scf.yield %7#0, %7#1, %7#2 : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16> + } + scf.forall.in_parallel { + 
tensor.parallel_insert_slice %6#2 into %arg7[%3, %4] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<128x256xbf16> + tensor.parallel_insert_slice %6#1 into %arg6[%3, %4] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<128x256xbf16> + tensor.parallel_insert_slice %6#0 into %arg5[%3, %4] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<128x256xbf16> + } + } + return %2#2 : tensor<128x256xbf16> + } -// CHECK-LABEL: func @elem_pack_transpose_inner_and_outer_dims2_test8 +// CHECK-LABEL: func @elem_pack_transpose_inner_dims_test5 // CHECK: %[[C32:.*]] = arith.constant 32 : index -// CHECK: %[[C2:.*]] = arith.constant 2 : index -// CHECK: %[[C64:.*]] = arith.constant 64 : index +// CHECK: %[[C4:.*]] = arith.constant 4 : index +// CHECK: %[[C256:.*]] = arith.constant 256 : index // CHECK: %[[C16:.*]] = arith.constant 16 : index -// CHECK: %[[C57:.*]] = arith.constant 57 : index -// CHECK: %[[C56:.*]] = arith.constant 56 : index +// CHECK: %[[C128:.*]] = arith.constant 128 : index // CHECK: %[[C1:.*]] = arith.constant 1 : index // CHECK: %[[C0:.*]] = arith.constant 0 : index -// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[EMPTY0:.*]] = tensor.empty() : tensor<1x56x57x2x32xf32> -// CHECK: %[[EMPTY1:.*]] = tensor.empty() : tensor<1x56x57x64xf32> -// CHECK: %[[EMPTY2:.*]] = tensor.empty() : tensor<1x2x56x57x32xf32> +// CHECK: %[[C32I32:.*]] = arith.constant 0 : i32 +// CHECK: %[[EMPTY0:.*]] = tensor.empty() : tensor<4x32x16x16xi32> +// CHECK: %[[EMPTY1:.*]] = tensor.empty() : tensor<128x256xi32> +// CHECK: %[[EMPTY2:.*]] = tensor.empty() : tensor<4x16x16x32xi32> // CHECK: scf.for -// CHECK: // CHECK: scf.for // CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<128x256xi32>, vector<16xi32> // CHECK: %[[ADD0:.*]] = arith.addi %[[READ0]], %[[READ0]] : vector<16xi32> @@ -403,6 +318,150 @@ // CHECK-COUNT-4: scf.for // CHECK: %[[READ2:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<4x32x16x16xi32>, vector<1xi32> // CHECK: %[[WRITE2:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<1xi32>, tensor<4x16x16x32xi32> +#map5 = affine_map<(d0, d1) -> (d0, d1)> +func.func @elem_pack_transpose_inner_dims_test5(%arg0: tensor<128x256xi32>, %dest: tensor<4x16x16x32xi32>) -> tensor<4x16x16x32xi32>{ + %init = tensor.empty() : tensor<128x256xi32> + %elem = linalg.generic {indexing_maps = [#map5, #map5], iterator_types = ["parallel", "parallel"]} + ins(%arg0 : tensor<128x256xi32>) + outs(%init : tensor<128x256xi32>) { + ^bb0(%arg3: i32, %arg4: i32): + %4 = arith.addi %arg3, %arg3 : i32 + linalg.yield %4 : i32 + } -> tensor<128x256xi32> + %pack = tensor.pack %elem + inner_dims_pos = [1, 0] + inner_tiles = [16, 32] + into %dest : tensor<128x256xi32> -> tensor<4x16x16x32xi32> + return %pack : tensor<4x16x16x32xi32> +} + +// CHECK-LABEL: func @elem_pack_transpose_outer_dims_test6 +// CHECK: %[[C32:.*]] = arith.constant 32 : index +// CHECK: %[[C4:.*]] = arith.constant 4 : index +// CHECK: %[[C256:.*]] = arith.constant 256 : index +// CHECK: %[[C16:.*]] = arith.constant 16 : index +// CHECK: %[[C128:.*]] = arith.constant 128 : index +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[C0I32:.*]] = arith.constant 0 : i32 +// CHECK: %[[EMPTY0:.*]] = tensor.empty() : tensor<4x32x16x16xi32> +// CHECK: %[[EMPTY1:.*]] = tensor.empty() : tensor<128x256xi32> +// CHECK: %[[EMPTY2:.*]] = tensor.empty() : tensor<16x4x32x16xi32> +// CHECK-COUNT-2: scf.for +// 
CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<128x256xi32>, vector<16xi32> +// CHECK: %[[ADD0:.*]] = arith.addi %[[READ0]], %[[READ0]] : vector<16xi32> +// CHECK: %[[WRITE0:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<16xi32>, tensor<128x256xi32> +// CHECK: scf.for %[[arg2:.*]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[arg3:.*]] = %[[EMPTY0]]) -> (tensor<4x32x16x16xi32>) +// CHECK: scf.for %[[arg4:.*]] = %[[C0]] to %[[C32]] step %[[C1]] iter_args(%[[arg5:.*]] = %[[arg3]]) -> (tensor<4x32x16x16xi32>) +// CHECK: scf.for %[[arg6:.*]] = %[[C0]] to %[[C16]] step %[[C1]] iter_args(%[[arg7:.*]] = %[[arg5]]) -> (tensor<4x32x16x16xi32>) +// CHECK: scf.for %[[arg8:.*]] = %[[C0]] to %[[C16]] step %[[C16]] iter_args(%[[arg9:.*]] = %[[arg7]]) -> (tensor<4x32x16x16xi32>) +// CHECK: %[[APPLY0:.*]] = affine.apply #[[map7]]()[%[[arg2]], %[[arg4]]] +// CHECK: %[[APPLY1:.*]] = affine.apply #[[map8]]()[%[[arg6]], %[[arg8]]] +// CHECK: %[[READ1:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<128x256xi32>, vector<16xi32> +// CHECK: %[[WRITE1:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<16xi32>, tensor<4x32x16x16xi32> +// CHECK: scf.for %[[arg2:.*]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[arg3:.*]] = %[[EMPTY2]]) -> (tensor<16x4x32x16xi32>) +// CHECK: scf.for %[[arg4:.*]] = %[[C0]] to %[[C32]] step %[[C1]] iter_args(%[[arg5:.*]] = %[[arg3]]) -> (tensor<16x4x32x16xi32>) +// CHECK: scf.for %[[arg6:.*]] = %[[C0]] to %[[C16]] step %[[C1]] iter_args(%[[arg7:.*]] = %[[arg5]]) -> (tensor<16x4x32x16xi32>) +// CHECK: scf.for %[[arg8:.*]] = %[[C0]] to %[[C16]] step %[[C1]] iter_args(%[[arg9:.*]] = %[[arg7]]) -> (tensor<16x4x32x16xi32>) +// CHECK: %[[READ2:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<4x32x16x16xi32>, vector<1xi32> +// CHECK: %[[WRITE2:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<1xi32>, tensor<16x4x32x16xi32> +#map6 = affine_map<(d0, d1) -> (d0, d1)> +func.func @elem_pack_transpose_outer_dims_test6(%arg0: tensor<128x256xi32>, %dest: tensor<16x4x32x16xi32>) -> tensor<16x4x32x16xi32>{ + %init = tensor.empty() : tensor<128x256xi32> + %elem = linalg.generic {indexing_maps = [#map6, #map6], iterator_types = ["parallel", "parallel"]} + ins(%arg0 : tensor<128x256xi32>) + outs(%init : tensor<128x256xi32>) { + ^bb0(%arg3: i32, %arg4: i32): + %4 = arith.addi %arg3, %arg3 : i32 + linalg.yield %4 : i32 + } -> tensor<128x256xi32> + %pack = tensor.pack %elem + outer_dims_perm = [1, 0] + inner_dims_pos = [0, 1] + inner_tiles = [32, 16] + into %dest : tensor<128x256xi32> -> tensor<16x4x32x16xi32> + return %pack : tensor<16x4x32x16xi32> +} + +// CHECK-LABEL: func @elem_pack_transpose_inner_and_outer_dims_test7 +// CHECK: %[[C32:.*]] = arith.constant 32 : index +// CHECK: %[[C4:.*]] = arith.constant 4 : index +// CHECK: %[[C256:.*]] = arith.constant 256 : index +// CHECK: %[[C16:.*]] = arith.constant 16 : index +// CHECK: %[[C128:.*]] = arith.constant 128 : index +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[C0I32:.*]] = arith.constant 0 : i32 +// CHECK: %[[EMPTY0:.*]] = tensor.empty() : tensor<4x32x16x16xi32> +// CHECK: %[[EMPTY1:.*]] = tensor.empty() : tensor<128x256xi32> +// CHECK: %[[EMPTY2:.*]] = tensor.empty() : tensor<16x4x16x32xi32> +// CHECK-COUNT-2: scf.for +// CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : 
tensor<128x256xi32>, vector<16xi32> +// CHECK: %[[ADD0:.*]] = arith.addi %[[READ0]], %[[READ0]] : vector<16xi32> +// CHECK: %[[WRITE0:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<16xi32>, tensor<128x256xi32> +// CHECK: scf.for %[[arg2:.*]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[arg3:.*]] = %[[EMPTY0]]) -> (tensor<4x32x16x16xi32>) +// CHECK: scf.for %[[arg4:.*]] = %[[C0]] to %[[C32]] step %[[C1]] iter_args(%[[arg5:.*]] = %[[arg3]]) -> (tensor<4x32x16x16xi32>) +// CHECK: scf.for %[[arg6:.*]] = %[[C0]] to %[[C16]] step %[[C1]] iter_args(%[[arg7:.*]] = %[[arg5]]) -> (tensor<4x32x16x16xi32>) +// CHECK: scf.for %[[arg8:.*]] = %[[C0]] to %[[C16]] step %[[C16]] iter_args(%[[arg9:.*]] = %[[arg7]]) -> (tensor<4x32x16x16xi32>) +// CHECK: %[[APPLY0:.*]] = affine.apply #[[map7]]()[%[[arg2]], %[[arg4]]] +// CHECK: %[[APPLY1:.*]] = affine.apply #[[map8]]()[%[[arg6]], %[[arg8]]] +// CHECK: %[[READ1:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<128x256xi32>, vector<16xi32> +// CHECK: %[[WRITE1:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<16xi32>, tensor<4x32x16x16xi32> +// CHECK: scf.for %[[arg2:.*]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[arg3:.*]] = %[[EMPTY2]]) -> (tensor<16x4x16x32xi32>) +// CHECK: scf.for %[[arg4:.*]] = %[[C0]] to %[[C32]] step %[[C1]] iter_args(%[[arg5:.*]] = %[[arg3]]) -> (tensor<16x4x16x32xi32>) +// CHECK: scf.for %[[arg6:.*]] = %[[C0]] to %[[C16]] step %[[C1]] iter_args(%[[arg7:.*]] = %[[arg5]]) -> (tensor<16x4x16x32xi32>) +// CHECK: scf.for %[[arg8:.*]] = %[[C0]] to %[[C16]] step %[[C1]] iter_args(%[[arg9:.*]] = %[[arg7]]) -> (tensor<16x4x16x32xi32>) +// CHECK: %[[READ2:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<4x32x16x16xi32>, vector<1xi32> +// CHECK: %[[WRITE2:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<1xi32>, tensor<16x4x16x32xi32> +#map7 = affine_map<(d0, d1) -> (d0, d1)> +func.func @elem_pack_transpose_inner_and_outer_dims_test7(%arg0: tensor<128x256xi32>, %dest: tensor<16x4x16x32xi32>) -> tensor<16x4x16x32xi32>{ + %init = tensor.empty() : tensor<128x256xi32> + %elem = linalg.generic {indexing_maps = [#map7, #map7], iterator_types = ["parallel", "parallel"]} + ins(%arg0 : tensor<128x256xi32>) + outs(%init : tensor<128x256xi32>) { + ^bb0(%arg3: i32, %arg4: i32): + %4 = arith.addi %arg3, %arg3 : i32 + linalg.yield %4 : i32 + } -> tensor<128x256xi32> + %pack = tensor.pack %elem + outer_dims_perm = [1, 0] + inner_dims_pos = [1, 0] + inner_tiles = [16, 32] + into %dest : tensor<128x256xi32> -> tensor<16x4x16x32xi32> + return %pack : tensor<16x4x16x32xi32> +} + +// CHECK-LABEL: func @elem_pack_transpose_inner_and_outer_dims2_test8 +// CHECK: %[[C32:.*]] = arith.constant 32 : index +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[C64:.*]] = arith.constant 64 : index +// CHECK: %[[C16:.*]] = arith.constant 16 : index +// CHECK: %[[C57:.*]] = arith.constant 57 : index +// CHECK: %[[C56:.*]] = arith.constant 56 : index +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[EMPTY0:.*]] = tensor.empty() : tensor<1x56x57x2x32xf32> +// CHECK: %[[EMPTY1:.*]] = tensor.empty() : tensor<1x56x57x64xf32> +// CHECK: %[[EMPTY2:.*]] = tensor.empty() : tensor<1x2x56x57x32xf32> +// CHECK-COUNT-4: scf.for +// CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<64xf32>, 
vector<16xf32> +// CHECK: %[[WRITE0:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<16xf32>, tensor<1x56x57x64xf32> +// CHECK: scf.for %[[arg2:.*]] = %[[C0]] to %[[C1]] step %[[C1]] iter_args(%[[arg3:.*]] = %[[EMPTY0]]) -> (tensor<1x56x57x2x32xf32>) +// CHECK: scf.for %[[arg4:.*]] = %[[C0]] to %[[C56]] step %[[C1]] iter_args(%[[arg5:.*]] = %[[arg3]]) -> (tensor<1x56x57x2x32xf32>) +// CHECK: scf.for %[[arg6:.*]] = %[[C0]] to %[[C57]] step %[[C1]] iter_args(%[[arg7:.*]] = %[[arg5]]) -> (tensor<1x56x57x2x32xf32>) +// CHECK: scf.for %[[arg8:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[arg9:.*]] = %[[arg7]]) -> (tensor<1x56x57x2x32xf32>) +// CHECK: scf.for %[[arg10:.*]] = %[[C0]] to %[[C32]] step %[[C16]] iter_args(%[[arg11:.*]] = %[[arg9]]) -> (tensor<1x56x57x2x32xf32>) +// CHECK: %[[APPLY0:.*]] = affine.apply #[[map7]]()[%[[arg8]], %[[arg10]]] +// CHECK: %[[READ1:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<1x56x57x64xf32>, vector<16xf32> +// CHECK: %[[WRITE1:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<16xf32>, tensor<1x56x57x2x32xf32> +// CHECK: scf.for %[[arg2:.*]] = %[[C0]] to %[[C1]] step %[[C1]] iter_args(%[[arg3:.*]] = %[[EMPTY2]]) -> (tensor<1x2x56x57x32xf32>) +// CHECK: scf.for %[[arg4:.*]] = %[[C0]] to %[[C56]] step %[[C1]] iter_args(%[[arg5:.*]] = %[[arg3]]) -> (tensor<1x2x56x57x32xf32>) +// CHECK: scf.for %[[arg6:.*]] = %[[C0]] to %[[C57]] step %[[C1]] iter_args(%[[arg7:.*]] = %[[arg5]]) -> (tensor<1x2x56x57x32xf32>) +// CHECK: scf.for %[[arg8:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[arg9:.*]] = %[[arg7]]) -> (tensor<1x2x56x57x32xf32>) +// CHECK: scf.for %[[arg10:.*]] = %[[C0]] to %[[C32]] step %[[C1]] iter_args(%[[arg11:.*]] = %[[arg9]]) -> (tensor<1x2x56x57x32xf32>) +// CHECK: %[[READ2:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<1x56x57x2x32xf32>, vector<1xf32> +// CHECK: %[[WRITE2:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<1xf32>, tensor<1x2x56x57x32xf32> #map8 = affine_map<(d0, d1, d2, d3) -> (d3)> #map9 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> func.func @elem_pack_transpose_inner_and_outer_dims2_test8(%arg0: tensor<64xf32>, %dest: tensor<1x2x56x57x32xf32>) -> tensor<1x2x56x57x32xf32> { @@ -420,21 +479,120 @@ func.func @elem_pack_transpose_inner_and_outer_dims2_test8(%arg0: tensor<64xf32> } -// // CHECK-LABEL: func @broadcast_same_shape_test9 -// // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 -// // CHECK: %[[C0:.*]] = arith.constant 0 : index -// // CHECK: %[[C1:.*]] = arith.constant 1 : index -// // CHECK: %[[C2:.*]] = arith.constant 2 : index -// // CHECK: %[[C16:.*]] = arith.constant 16 : index -// // CHECK: scf.for -// // CHECK: scf.for -// // CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<2x16xf32>, vector<16xf32> -// // CHECK: %[[READ1:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<16xf32>, vector<16xf32> -// // CHECK: %[[ADD0:.*]] = arith.addf %[[READ1]], %[[READ0]] : vector<16xf32> -// // CHECK: %[[WRITE0:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<16xf32>, tensor<2x16xf32> -// func.func @broadcast_same_shape_test9(%input: tensor<16xf32>, %init: tensor<2x16xf32>) -> tensor<2x16xf32> { -// %empty = tensor.empty() : tensor<2x16xf32> -// %0 = linalg.broadcast ins(%input: tensor<16xf32>) outs(%empty: tensor<2x16xf32>) dimensions = [0] -// %1 = linalg.add ins(%0, %init : tensor<2x16xf32>, 
tensor<2x16xf32>) outs(%init : tensor<2x16xf32>) -> tensor<2x16xf32> -// return %1 : tensor<2x16xf32> -// } +// CHECK-LABEL: func @broadcast_same_shape_test9 +// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[C16:.*]] = arith.constant 16 : index +// CHECK: scf.for +// CHECK: scf.for +// CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<2x16xf32>, vector<16xf32> +// CHECK: %[[READ1:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<16xf32>, vector<16xf32> +// CHECK: %[[ADD0:.*]] = arith.addf %[[READ1]], %[[READ0]] : vector<16xf32> +// CHECK: %[[WRITE0:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<16xf32>, tensor<2x16xf32> +func.func @broadcast_same_shape_test9(%input: tensor<16xf32>, %init: tensor<2x16xf32>) -> tensor<2x16xf32> { + %empty = tensor.empty() : tensor<2x16xf32> + %0 = linalg.broadcast ins(%input: tensor<16xf32>) outs(%empty: tensor<2x16xf32>) dimensions = [0] + %1 = linalg.add ins(%0, %init : tensor<2x16xf32>, tensor<2x16xf32>) outs(%init : tensor<2x16xf32>) -> tensor<2x16xf32> + return %1 : tensor<2x16xf32> +} + +// CHECK-LABEL: func @reduce_single_test10 +// CHECK: %[[C32:.*]] = arith.constant 32 : index +// CHECK: %[[C64:.*]] = arith.constant 64 : index +// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[C16:.*]] = arith.constant 16 : index +// CHECK: scf.for %[[arg2:.*]] = %[[C0]] to %[[C16]] step %[[C1]] iter_args(%[[arg3:.*]] = %arg1) -> (tensor<16x64xf32>) +// CHECK: scf.for %[[arg4:.*]] = %[[C0]] to %[[C64]] step %[[C16]] iter_args(%[[arg5:.*]] = %[[arg3]]) -> (tensor<16x64xf32>) +// CHECK: %[[READ0:.*]] = vector.transfer_read %[[arg5]][%[[arg2]], %[[arg4]]], %[[CST]] {in_bounds = [true]} : tensor<16x64xf32>, vector<16xf32> +// CHECK: scf.for %[[arg6:.*]] = %[[C0]] to %[[C32]] step %[[C1]] iter_args(%[[arg7:.*]] = %[[READ0]]) -> (vector<16xf32>) +// CHECK: %[[READ1:.*]] = vector.transfer_read %arg0[%[[arg2]], %[[arg6]], %[[arg4]]], {{.*}} {in_bounds = [true]} : tensor<16x32x64xf32>, vector<16xf32> +// CHECK: %[[ADD0:.*]] = arith.addf %[[READ1]], %[[arg7]] : vector<16xf32> +// CHECK: %[[WRITE0:.*]] = vector.transfer_write {{.*}}, %[[arg5]][%[[arg2]], %[[arg4]]] {in_bounds = [true]} : vector<16xf32>, tensor<16x64xf32> +func.func @reduce_single_test10(%input: tensor<16x32x64xf32>, + %init: tensor<16x64xf32>) -> tensor<16x64xf32> { + %reduce = linalg.reduce + ins(%input:tensor<16x32x64xf32>) + outs(%init:tensor<16x64xf32>) + dimensions = [1] + (%in: f32, %out: f32) { + %0 = arith.addf %out, %in: f32 + linalg.yield %0: f32 + } + func.return %reduce : tensor<16x64xf32> +} + +// CHECK-LABEL: func @reduce_fusePostOp_test11 +// CHECK: %[[C64:.*]] = arith.constant 64 : index +// CHECK: %[[C32:.*]] = arith.constant 32 : index +// CHECK: %[[C16:.*]] = arith.constant 16 : index +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[EMPTY0:.*]] = tensor.empty() : tensor<16x32x64xf32> +// CHECK-COUNT-3: scf.for +// CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<16x32x64xf32>, vector<16xf32> +// CHECK: %[[ADD0:.*]] = arith.addf %[[READ0]], %[[READ0]] : 
vector<16xf32> +// CHECK: %[[WRITE0:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<16xf32>, tensor<16x32x64xf32> +// CHECK: scf.for %[[arg2:.*]] = %[[C0]] to %[[C16]] step %[[C1]] iter_args(%[[arg3:.*]] = %arg1) -> (tensor<16x64xf32>) +// CHECK: scf.for %[[arg4:.*]] = %[[C0]] to %[[C64]] step %[[C16]] iter_args(%[[arg5:.*]] = %[[arg3]]) -> (tensor<16x64xf32>) +// CHECK: %[[READ1:.*]] = vector.transfer_read %[[arg5]][%[[arg2]], %[[arg4]]], %[[CST]] {in_bounds = [true]} : tensor<16x64xf32>, vector<16xf32> +// CHECK: scf.for %[[arg6:.*]] = %[[C0]] to %[[C32]] step %[[C1]] iter_args(%[[arg7:.*]] = %[[READ1]]) -> (vector<16xf32>) +// CHECK: %[[READ2:.*]] = vector.transfer_read {{.*}}[%[[arg2]], %[[arg6]], %[[arg4]]], {{.*}} {in_bounds = [true]} : tensor<16x32x64xf32>, vector<16xf32> +// CHECK: %[[ADD0:.*]] = arith.addf %[[READ2]], %[[arg7]] : vector<16xf32> +// CHECK: %[[MUL:.*]] = arith.mulf {{.*}}, {{.*}} : vector<16xf32> +// CHECK: %[[WRITE0:.*]] = vector.transfer_write {{.*}}, %[[arg5]][%[[arg2]], %[[arg4]]] {in_bounds = [true]} : vector<16xf32>, tensor<16x64xf32> +func.func @reduce_fusePostOp_test11(%input: tensor<16x32x64xf32>, + %init: tensor<16x64xf32>) -> tensor<16x64xf32> { + %0 = linalg.add ins(%input, %input : tensor<16x32x64xf32>,tensor<16x32x64xf32>) + outs(%input : tensor<16x32x64xf32>) -> tensor<16x32x64xf32> + %reduce = linalg.reduce + ins(%0:tensor<16x32x64xf32>) + outs(%init:tensor<16x64xf32>) + dimensions = [1] + (%in: f32, %out: f32) { + %2 = arith.addf %out, %in: f32 + linalg.yield %2: f32 + } + %1 = linalg.mul ins(%reduce, %reduce : tensor<16x64xf32>, tensor<16x64xf32>) outs(%init: tensor<16x64xf32>) -> tensor<16x64xf32> + func.return %1 : tensor<16x64xf32> +} + +// CHECK-LABEL: func @reduce_fuse_test12 +// CHECK: %[[C64:.*]] = arith.constant 64 : index +// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32> +// CHECK: %[[C32:.*]] = arith.constant 32 : index +// CHECK: %[[CST_0:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[C16:.*]] = arith.constant 16 : index +// CHECK: scf.for %[[arg2:.*]] = %[[C0]] to %[[C16]] step %[[C1]] iter_args(%[[arg3:.*]] = %arg1) -> (tensor<16x32xf32>) +// CHECK: scf.for %[[arg4:.*]] = %[[C0]] to %[[C32]] step %[[C1]] iter_args(%[[arg5:.*]] = %[[arg3]]) -> (tensor<16x32xf32>) +// CHECK: %[[READ0:.*]] = vector.transfer_read %[[arg5]][%[[arg2]], %[[arg4]]], %[[CST_0]] {in_bounds = [true]} : tensor<16x32xf32>, vector<16xf32> +// CHECK: scf.for %[[arg6:.*]] = %[[C0]] to %[[C16]] step %[[C1]] iter_args(%[[arg7:.*]] = %[[READ0]]) -> (vector<16xf32>) +// CHECK: scf.for %[[arg8:.*]] = %[[C0]] to %[[C64]] step %[[C16]] iter_args(%[[arg9:.*]] = %[[CST]]) -> (vector<16xf32>) +// CHECK: %[[READ1:.*]] = vector.transfer_read %arg0[%[[arg2]], %[[arg4]], %[[arg6]]], {{.*}} {in_bounds = [true]} : tensor<16x32x64xf32>, vector<16xf32> +// CHECK: %[[ADD0:.*]] = arith.addf %[[READ1]], %[[READ1]] : vector<16xf32> +// CHECK: %[[ADD1:.*]] = arith.addf %[[ADD0]], %[[arg9]] : vector<16xf32> +// CHECK: %[[REDUCTION:.*]] = vector.reduction , {{.*}} : vector<16xf32> into f32 +// CHECK: %[[INSERT:.*]] = vector.insert %[[REDUCTION]], %[[arg7]] [%[[arg6]]] : f32 into vector<16xf32> +// CHECK: %[[MUL:.*]] = arith.mulf {{.*}}, {{.*}} : vector<16xf32> +// CHECK: %[[WRITE0:.*]] = vector.transfer_write {{.*}}, %[[arg5]][%[[arg2]], %[[arg4]]] {in_bounds = [true]} : vector<16xf32>, tensor<16x32xf32> +func.func 
@reduce_fuse_test12(%input: tensor<16x32x64xf32>, + %init: tensor<16x32xf32>) -> tensor<16x32xf32> { + %0 = linalg.add ins(%input, %input : tensor<16x32x64xf32>,tensor<16x32x64xf32>) + outs(%input : tensor<16x32x64xf32>) -> tensor<16x32x64xf32> + %reduce = linalg.reduce + ins(%0:tensor<16x32x64xf32>) + outs(%init:tensor<16x32xf32>) + dimensions = [2] + (%in: f32, %out: f32) { + %2 = arith.addf %out, %in: f32 + linalg.yield %2: f32 + } + %1 = linalg.mul ins(%reduce, %reduce : tensor<16x32xf32>, tensor<16x32xf32>) outs(%init: tensor<16x32xf32>) -> tensor<16x32xf32> + func.return %1 : tensor<16x32xf32> +} From 562657c37927a518b71bfb18cef384b7d2a34394 Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Tue, 3 Sep 2024 15:56:08 +0800 Subject: [PATCH 36/66] simplify reduction parallel generate for loop --- lib/gc/Transforms/CPUPhysicalRegisterPass.cpp | 490 ++++++++---------- lib/gc/Transforms/TilingVector.h | 31 +- .../gc/Transforms/cpu-vetor-distribution.mlir | 2 +- 3 files changed, 228 insertions(+), 295 deletions(-) diff --git a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp index e12e6fea6..b42e22f47 100644 --- a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp +++ b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp @@ -86,21 +86,18 @@ template , T>> static bool isOperationsHasDefUseRelation(Operation *op1, Operation *op2) { - for (Value opd : op2->getOperands()) { - if (opd.getDefiningOp() == op1) { + for (Value opd : op2->getOperands()) + if (opd.getDefiningOp() == op1) return true; - } - } return false; } /// Get the index position of the first element that is true static size_t getFirstTrueIndex(ArrayRef ararys) { - for (size_t i = 0; i < ararys.size(); i++) { - if (!ararys[i]) { + for (size_t i = 0; i < ararys.size(); i++) + if (!ararys[i]) return i; - } - } + return -1; } @@ -152,25 +149,22 @@ bool isNotSupportOperation(Operation *op) { /// whether operation is operate on dynamic shape bool hasDynamicShape(Operation *op) { auto isDynamicShapedType = [](Value x) { - if (auto type = dyn_cast(x.getType())) { - if (ShapedType::isDynamicShape(type.getShape())) { + if (auto type = dyn_cast(x.getType())) + if (ShapedType::isDynamicShape(type.getShape())) return true; - } - } + return false; }; // Check operands data type. - for (auto x : op->getOperands()) { - if (isDynamicShapedType(x)) { + for (auto x : op->getOperands()) + if (isDynamicShapedType(x)) return true; - } - } + // Check results data type. 
- for (auto x : op->getResults()) { - if (isDynamicShapedType(x)) { + for (auto x : op->getResults()) + if (isDynamicShapedType(x)) return true; - } - } + return false; } @@ -219,12 +213,12 @@ int TypeHelper::generateValidSteps(int steps, VectorType type) { auto typebits = type.getElementTypeBitWidth(); const int favx512bits = 512; const int favx2bits = 256; - if (HWInfo.favx512f) { + if (HWInfo.favx512f) return favx512bits / typebits; - } - if (HWInfo.favx2) { + + if (HWInfo.favx2) return favx2bits / typebits; - } + // invalid hardware LDBG("Please check the hardware information."); assert(false && "Invalid hardware."); @@ -332,9 +326,9 @@ mlir::FailureOr getOperationVectorType(Operation *op, /// prev-op, may need to use result vectortype /// default will return the opeation result type mlir::FailureOr getOperationMaxVectorType(Operation *op) { - if (!op) { + if (not op) return failure(); - } + auto isDynamicType = [](VectorType &type) { return !type.hasStaticShape(); }; auto ret = TypeSwitch>(op) @@ -342,10 +336,10 @@ mlir::FailureOr getOperationMaxVectorType(Operation *op) { [&](vector::TransferWriteOp transferWriteOp) -> mlir::FailureOr { auto retType = - dyn_cast(transferWriteOp.getOperandTypes()[0]); - if (retType) { + cast(transferWriteOp.getOperandTypes()[0]); + if (retType) return retType; - } + LDBG("TransferWrite Operation has wrong vector to write."); return failure(); }) @@ -359,23 +353,22 @@ mlir::FailureOr getOperationMaxVectorType(Operation *op) { return cast(multiReductionOp.getSourceVectorType()); }) .Default([&](Operation *op) -> mlir::FailureOr { - if (op->getResultTypes().empty() and - op->getOperandTypes().empty()) { + if (op->getResultTypes().empty() and op->getOperandTypes().empty()) return failure(); - } - if (op->getResultTypes().empty()) { - return dyn_cast(op->getOperandTypes()[0]); - } - if (op->getOperandTypes().empty()) { - return dyn_cast(op->getResultTypes()[0]); - } - auto opdType = dyn_cast(op->getOperandTypes()[0]); - auto retType = dyn_cast(op->getResultTypes()[0]); + + if (op->getResultTypes().empty()) + return cast(op->getOperandTypes()[0]); + + if (op->getOperandTypes().empty()) + return cast(op->getResultTypes()[0]); + + auto opdType = cast(op->getOperandTypes()[0]); + auto retType = cast(op->getResultTypes()[0]); return opdType.getRank() > retType.getRank() ? opdType : retType; }); - if (!failed(ret) and isDynamicType(ret.value())) { + if (!failed(ret) and isDynamicType(ret.value())) return failure(); - } + return ret; } @@ -389,9 +382,9 @@ VectorType TypeHelper::getVectorzedType(Operation *op, uint32_t loop_step) { return VectorType(); } auto vectorizedType = baseType.value(); - if (loop_step == 0) { + if (loop_step == 0) loop_step = getDataTypeValidSteps(vectorizedType); - } + return VectorType::get({loop_step}, vectorizedType.getElementType()); } @@ -516,22 +509,23 @@ float bfloat2float(uint16_t bfloatBits) { } bool isReadWriteOnLastDim(Operation *op) { - if (isa(op)) { - auto permutationMap = + if (isReadOrWriteOperation(op)) { + AffineMap permutationMap = dyn_cast(op) - ? dyn_cast(op).getPermutationMap() - : dyn_cast(op).getPermutationMap(); - auto rank = + ? cast(op).getPermutationMap() + : cast(op).getPermutationMap(); + int64_t rank = dyn_cast(op) - ? dyn_cast(op->getOperand(0).getType()).getRank() - : dyn_cast(op->getOperand(1).getType()).getRank(); - auto dimExpr = permutationMap.getResults(); + ? 
cast(op->getOperand(0).getType()).getRank() + : cast(op->getOperand(1).getType()).getRank(); + ArrayRef dimExpr = permutationMap.getResults(); bool find = false; - for (auto &expr : dimExpr) { + for (const auto &expr : dimExpr) if (isLastDim(expr, rank)) { find = true; + break; } - } + return find; } LDBG("The operation is not a read or write operation." << *op << "\n"); @@ -539,48 +533,6 @@ bool isReadWriteOnLastDim(Operation *op) { return false; } -// std::variant numeric_zero(Type type) { -// Type t1 = getElementTypeOrSelf(type); -// if (t1.isF32()) { -// return 0.f; -// } else if (t1.isBF16()) { -// return bfloat2float(float2bfloat(0.f)); -// } else if (t1.isF16()) { -// return half2float(float2half(0.f)); -// } else if (t1.isSignedInteger(8)) { -// return int64_t(0); -// } else if (t1.isSignedInteger(32)) { -// return int64_t(0); -// } else if (t1.isSignlessInteger(8) or t1.isSignlessInteger(32)) { -// return int64_t(0); -// } else { -// LDBG("Unsupported data type: " << t1 << "\n"); -// assert(0 && "unsupported data type"); -// return (int64_t)0; -// } -// } - -// std::variant numeric_one(Type type) { -// Type t1 = getElementTypeOrSelf(type); -// if (t1.isF32()) { -// return 1.f; -// } else if (t1.isBF16()) { -// return bfloat2float(float2bfloat(1.f)); -// } else if (t1.isF16()) { -// return half2float(float2half(1.f)); -// } else if (t1.isSignedInteger(8)) { -// return int64_t(1); -// } else if (t1.isSignedInteger(32)) { -// return int64_t(1); -// } else if (t1.isSignlessInteger(8) or t1.isSignlessInteger(32)) { -// return int64_t(1); -// } else { -// LDBG("Unsupported data type: " << t1 << "\n"); -// assert(0 && "unsupported data type"); -// return (int64_t)1; -// } -// } - std::variant numeric_limits_minimum(Type type) { Type t1 = getElementTypeOrSelf(type); if (t1.isF32()) { @@ -627,7 +579,7 @@ std::variant numericLimitsMaximum(Type type) { } } -template +template T getInitValForReduce(vector::CombiningKind kind, Type t) { T result; Type t1 = getElementTypeOrSelf(t); @@ -704,11 +656,10 @@ void setOpVectorizationPermutationMap(Operation *op, OpBuilder &rewriter, scf::YieldOp maybeYieldValue(OpBuilder &b, Location loc, const ValueRange &value) { bool hasRetVal = !value.empty(); - if (hasRetVal) { + if (hasRetVal) return b.create(loc, value); - } else { + else return b.create(loc); - } } Type getScalarType(Operation *op) { @@ -726,7 +677,7 @@ Type getScalarType(Operation *op) { Operation *createTensorEmptyBefore(Operation *op) { - auto rtType = dyn_cast(op->getResultTypes()[0]); + auto rtType = cast(op->getResultTypes()[0]); IRRewriter reWriter(op); Block *block = op->getBlock(); @@ -736,10 +687,9 @@ Operation *createTensorEmptyBefore(Operation *op) { SmallVector dynDims; for (unsigned i = 0; i < rtType.getRank(); i++) { shapes.push_back(rtType.getDimSize(i)); - if (rtType.isDynamicDim(i)) { + if (rtType.isDynamicDim(i)) dynDims.push_back( reWriter.create(op->getLoc(), op->getResult(0), i)); - } } auto emtpyOp = reWriter.create( op->getLoc(), rtType.getShape(), rtType.getElementType(), dynDims); @@ -751,18 +701,18 @@ Value getOperationResultTensor( Operation *op, DenseMap &visitedOperation) { OpResult result = op->getResults()[0]; for (Operation *x : result.getUsers()) { - if (!isa(x)) { + if (not isa(x)) continue; - } + Value sourceTensor = x->getOperands()[1]; Operation *srcOp = sourceTensor.getDefiningOp(); - if (!visitedOperation.contains(srcOp)) { + if (not visitedOperation.contains(srcOp)) continue; - } + size_t pos = visitedOperation[srcOp]; - if (pos > visitedOperation[op]) 
{ + if (pos > visitedOperation[op]) continue; - } + return sourceTensor; } LDBG("Result not write back to tensor."); @@ -771,13 +721,12 @@ Value getOperationResultTensor( } Operation *createTransferWriteOpAfter(Operation *op, const Value &dest) { - auto rtType = mlir::dyn_cast(op->getResultTypes()[0]); - auto rank = rtType.getRank(); - auto dstType = mlir::dyn_cast(dest.getType()); + auto rtType = cast(op->getResultTypes()[0]); + int64_t rank = rtType.getRank(); + auto dstType = cast(dest.getType()); IRRewriter reWriter(op); - auto zero = - reWriter.create(reWriter.getUnknownLoc(), 0); + auto zero = reWriter.create(op->getLoc(), 0); reWriter.setInsertionPointAfter(op); SmallVector inBoundsVal(rank, true); @@ -786,13 +735,12 @@ Operation *createTransferWriteOpAfter(Operation *op, const Value &dest) { SmallVector dynDims; for (unsigned i = 0; i < rtType.getRank(); i++) { shapes.push_back(rtType.getDimSize(i)); - if (rtType.isDynamicDim(i)) { - dynDims.push_back(reWriter.create(reWriter.getUnknownLoc(), - op->getResult(0), i)); - } + if (rtType.isDynamicDim(i)) + dynDims.push_back( + reWriter.create(op->getLoc(), op->getResult(0), i)); } return reWriter.create( - reWriter.getUnknownLoc(), + op->getLoc(), /*vector=*/op->getResult(0), /*source=*/dest, /*indices=*/SmallVector(dstType.getRank(), zero), @@ -827,23 +775,22 @@ Operation *CanonicalizerCommonUsedData::createTransferReadOpBefore( permutationMap[t] = srcReadOpAffineMap; getFusionStrategy().getOpAnchorPos()[t] = t.getVectorType().getRank() - 1; - return t; - } else { - SmallVector inBoundsVal(operandType.getRank(), true); - auto t = rewriter.create( - op->getLoc(), - /*vectorType=*/ - VectorType::get(operandType.getShape(), operandType.getElementType()), - /*source=*/operand, - /*indices=*/SmallVector(operandType.getRank(), zero), - /**affinemap*/ padValue, - /*inBounds=*/inBoundsVal); - DenseMap &permutationMap = getOpPermuationMap(); - permutationMap[t] = t.getPermutationMap(); - getFusionStrategy().getOpAnchorPos()[t] = t.getVectorType().getRank() - 1; - return t; } + SmallVector inBoundsVal(operandType.getRank(), true); + auto t = rewriter.create( + op->getLoc(), + /*vectorType=*/ + VectorType::get(operandType.getShape(), operandType.getElementType()), + /*source=*/operand, + /*indices=*/SmallVector(operandType.getRank(), zero), + /**affinemap*/ padValue, + /*inBounds=*/inBoundsVal); + DenseMap &permutationMap = getOpPermuationMap(); + permutationMap[t] = t.getPermutationMap(); + getFusionStrategy().getOpAnchorPos()[t] = t.getVectorType().getRank() - 1; + + return t; } // canonicalizing operation as tensor empty and transfer write the operation @@ -861,7 +808,6 @@ canonicalizeSourceOperation(Operation *op, vector::TransferReadOp *srcReadOp) { // transfer_read operation auto readOp = createTransferReadOpBefore(op, transferReadOperand, srcReadOp); - op->setOperand(operandIdx, readOp->getResults()[0]); return readOp->getResults()[0]; } @@ -880,12 +826,12 @@ void getOpSourceOps(Operation *op, DenseSet &srcOps) { DenseSet visited; visited.insert(op); while (!srcOperandsQueue.empty()) { - auto accOperand = srcOperandsQueue.front(); + Value accOperand = srcOperandsQueue.front(); srcOperandsQueue.pop_front(); auto accOperandOp = accOperand.getDefiningOp(); - if (!accOperandOp or visited.count(accOperandOp)) { + if (!accOperandOp or visited.count(accOperandOp)) continue; - } + visited.insert(accOperandOp); srcOps.insert(accOperandOp); auto accOperandOperands = accOperandOp->getOperands(); @@ -922,15 +868,14 @@ void 
getReductionInitAttr(vector::MultiDimReductionOp &multiReductionOp, Attribute &initValueAttr) { auto vecType = multiReductionOp.getSourceVectorType(); auto resultElementType = vecType.getElementType(); - if (isa(resultElementType)) { + if (isa(resultElementType)) initValueAttr = FloatAttr::get( resultElementType, getInitValForReduce(multiReductionOp.getKind(), vecType)); - } else { + else initValueAttr = IntegerAttr::get( resultElementType, getInitValForReduce(multiReductionOp.getKind(), vecType)); - } } void classifySourceRelatedOps(std::queue &accRelatedOps, @@ -942,11 +887,10 @@ void classifySourceRelatedOps(std::queue &accRelatedOps, while (!prevOps.empty()) { auto op = prevOps.front(); prevOps.pop(); - if (isSrcRelated(srcOps, op) or op == srcOp) { + if (isSrcRelated(srcOps, op) or op == srcOp) sourceRelatedOps.push(op); - } else { + else accRelatedOps.push(op); - } } } @@ -960,11 +904,10 @@ void classifyAccRelatedOps(std::queue &accRelatedOps, while (!prevOps.empty()) { auto op = prevOps.front(); prevOps.pop(); - if (isSrcRelated(srcOpsSet, op) or op == srcOp) { + if (isSrcRelated(srcOpsSet, op) or op == srcOp) accRelatedOps.push(op); - } else { + else sourceRelatedOps.push(op); - } } } @@ -972,12 +915,11 @@ void updateReduceReadWriteOperationOperand( const SmallVector &inductionVars, const SmallVector ¶llelAxis, Operation *op, MultiReduceOpAxisKind rdKind = MultiReduceOpAxisKind::Parallel) { - int indiceOffset = mlir::isa(op) ? 1 : 2; + int indiceOffset = isa(op) ? 1 : 2; for (auto [idx, inductionVar] : llvm::enumerate(inductionVars)) { - if (rdKind == MultiReduceOpAxisKind::Parallel && - idx >= parallelAxis.size()) { + if (rdKind == MultiReduceOpAxisKind::Parallel && idx >= parallelAxis.size()) break; - } + op->setOperand(idx + indiceOffset, inductionVar); } } @@ -992,8 +934,8 @@ vector::TransferReadOp ForLoopGenerator::cloneReductionTransferRead( assert(readOp && " Not transfer_read operation. 
Current multireduction " "operation may have wrong analysis IR."); - auto clonedOp = b.clone(*readOp, readMap); - auto newReadOp = mlir::dyn_cast(clonedOp); + Operation *clonedOp = b.clone(*readOp, readMap); + auto newReadOp = cast(clonedOp); updateReduceReadWriteOperationOperand(inductionVars, parallelAxis, newReadOp, rdKind); @@ -1017,8 +959,7 @@ makeNewTransferWriteOp(Value source, IRMapping &writeMap, OpBuilder &b, SmallVector &inductionVars) { IRRewriter bodyRewriter(b); auto writeOp = source.getDefiningOp(); - auto newWriteOp = - mlir::dyn_cast(b.clone(*writeOp, writeMap)); + auto newWriteOp = cast(b.clone(*writeOp, writeMap)); updateReduceReadWriteOperationOperand(inductionVars, parallelAxis, newWriteOp, MultiReduceOpAxisKind::Parallel); setOpVectorizationPermutationMap( @@ -1037,7 +978,7 @@ Value makeIndexArithConstantOp(OpBuilder &opBuilder, const Location &loc, void ForLoopGenerator::moveOperationsToCurrentForBody( const size_t groupIdx, const OpBuilder &b, ArrayRef inductionVars, - const DenseMap &operandIdxMap, const ValueRange &loopState, + const DenseMap &operandIdxMap, ValueRange loopState, DenseMap &originalOperandLoopArgsMap, std::queue &opQueue, DenseMap> &indiceLoopMap) { @@ -1054,23 +995,6 @@ void ForLoopGenerator::moveOperationsToCurrentForBody( } } -bool hasOtherOperations(const std::queue &opQ, - const Operation *multiReductionOp) { - bool res = false; - if (!opQ.empty()) { - std::queue tempQ(opQ); - while (!tempQ.empty()) { - auto cur = tempQ.front(); - tempQ.pop(); - if (!isReadOrWriteOperation(cur) and cur != multiReductionOp) { - res = true; - break; - } - } - } - return res; -}; - void ForLoopGenerator::getResultInCurrentOps( const size_t anchorIdx, const size_t groupId, const std::queue ops, SmallVector &results, @@ -1502,15 +1426,10 @@ scf::ForOp ForLoopGenerator::reductionAxisGenerateForLoop( /// This function also call reduction axis for loop scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( OpBuilder &opBuilder, const int groupIdx, const size_t parallelIdx, - DenseMap ¤tLoopStateIdxMap, const ValueRange &initArgs, - SmallVector &nextAnchorResults, - DenseMap &nextAnchorResultsIdxMap, - SmallVector &inductionVars, - DenseMap &originalOperandLoopArgsMap, - DenseMap &loopArgsOriginalOperandMap, - DenseMap &forResultOrignalResultMap, - DenseMap> &indiceLoopMap) { - auto &rdCanonicalizer = getMultiRdCanonicalizers()[groupIdx]; + DenseMap> &indiceLoopMap, + GenerateLoopHelper &loopHelperParam) { + MultiReductionCanonicalizer &rdCanonicalizer = + getMultiRdCanonicalizers()[groupIdx]; vector::MultiDimReductionOp &multiReductionOp = rdCanonicalizer.getCandidateOps()[0]; VectorType vectorType = rdCanonicalizer.getSourceType(); @@ -1528,17 +1447,17 @@ scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( // last dim reduction need to a generate dim=16 loop for fused with pre-op int dimSize = 0; - if (parallelIdx == parallelAxis.size()) { + if (parallelIdx == parallelAxis.size()) dimSize = getFusionStrategy().getGroupMaxSteps()[groupIdx]; - } else { + else dimSize = vectorType.getShape()[parallelAxis[parallelIdx]]; - } + Value numIter = makeIndexArithConstantOp(opBuilder, loc, dimSize); // Create a loop and move vectorized operation into loops. 
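  // Illustrative sketch (an assumption for readability, not taken from this
  // commit): with the reduce_single_test10 shapes from
  // cpu-vetor-distribution.mlir (a vector<16x32x64xf32> source reduced over
  // dim 1, parallel dims {0, 2}), this recursion is expected to emit roughly:
  //   scf.for %i0 = %c0 to %c16 step %c1   iter_args(...)  // parallelAxis[0]
  //     scf.for %i2 = %c0 to %c64 step %c16 iter_args(...) // parallelAxis[1], vectorized step
  //       scf.for %ir = %c0 to %c32 step %c1 iter_args(...) // reduction axis
  // Concrete bounds and steps depend on the source vector type and on
  // getGroupMaxSteps(), so treat this only as a reading aid.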
return opBuilder.create( - loc, zero, numIter, forSteps, initArgs, + loc, zero, numIter, forSteps, loopHelperParam.loopIterArgs, [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) { - inductionVars.emplace_back(iv); + loopHelperParam.inductionVars.emplace_back(iv); VectorFusionStrategy &fusionStrategy = getFusionStrategy(); DenseMap &opIndexMap = fusionStrategy.getOpGroupIndexMap(); @@ -1558,14 +1477,20 @@ scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( SmallVector nextAnchorArgs; std::queue movedQueue; DenseMap currentOriginalOperandMap = - originalOperandLoopArgsMap; + loopHelperParam.originalOperandLoopArgsMap; DenseMap currentOperandOriginalMap = - loopArgsOriginalOperandMap; + loopHelperParam.loopArgsOriginalOperandMap; + DenseMap currentLoopStateIdxMap = + loopHelperParam.currentLoopStateIdxMap; + movePreOpToCurrentAnchor( - parallelIdx, groupIdx, b, inductionVars, loopState, - currentLoopStateIdxMap, nextAnchorArgsIdxMap, nextAnchorArgs, - opQueue, movedQueue, originalOperandLoopArgsMap, - loopArgsOriginalOperandMap, indiceLoopMap); + parallelIdx, groupIdx, b, loopHelperParam.inductionVars, + loopState, loopHelperParam.currentLoopStateIdxMap, + nextAnchorArgsIdxMap, nextAnchorArgs, opQueue, movedQueue, + loopHelperParam.originalOperandLoopArgsMap, + loopHelperParam.loopArgsOriginalOperandMap, indiceLoopMap); + loopHelperParam.loopIterArgs = nextAnchorArgs; + loopHelperParam.currentLoopStateIdxMap = nextAnchorArgsIdxMap; if (parallelIdx == parallelAxis.size() - 1) { // Ensure accumalate expression appear in this parallel anchor @@ -1596,10 +1521,13 @@ scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( if (!argsSet.contains(accInitVal)) { nextAnchorArgs.emplace_back(accInitVal); nextAnchorArgsIdxMap[accInitVal] = nextAnchorArgs.size() - 1; - loopArgsOriginalOperandMap[accInitVal] = multiReductionAcc; - originalOperandLoopArgsMap[multiReductionAcc] = accInitVal; + loopHelperParam.loopArgsOriginalOperandMap[accInitVal] = + multiReductionAcc; + loopHelperParam.originalOperandLoopArgsMap[multiReductionAcc] = + accInitVal; } - + loopHelperParam.loopIterArgs = nextAnchorArgs; + loopHelperParam.nextAnchorResultsIdxMap = nextAnchorArgsIdxMap; } else { llvm::llvm_unreachable_internal( "Wrong accumualte source value. Because " @@ -1611,45 +1539,43 @@ scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( DenseMap originalResultForResultMap; // 2. 
generate next for loop if (rdCanonicalizer.hasLastDimReduction() or - parallelIdx < parallelAxis.size() - 1) { + parallelIdx < parallelAxis.size() - 1) nxtFor = parallelAxisGenerateForLoop( - b, groupIdx, parallelIdx + 1, nextAnchorArgsIdxMap, - nextAnchorArgs, nextAnchorResults, nextAnchorResultsIdxMap, - inductionVars, originalOperandLoopArgsMap, - loopArgsOriginalOperandMap, forResultOrignalResultMap, - indiceLoopMap); - } else if (parallelAxis.size() - 1 == parallelIdx) { - + b, groupIdx, parallelIdx + 1, indiceLoopMap, loopHelperParam); + else if (parallelAxis.size() - 1 == parallelIdx) nxtFor = reductionAxisGenerateForLoop( b, groupIdx, 0, parallelIdx + 1, nextAnchorArgsIdxMap, - nextAnchorArgs, originalOperandLoopArgsMap, - loopArgsOriginalOperandMap, nextAnchorResults, - nextAnchorResultsIdxMap, inductionVars, - forResultOrignalResultMap, originalResultForResultMap, - indiceLoopMap); - } + nextAnchorArgs, loopHelperParam.originalOperandLoopArgsMap, + loopHelperParam.loopArgsOriginalOperandMap, + loopHelperParam.nextAnchorResults, + loopHelperParam.nextAnchorResultsIdxMap, + loopHelperParam.inductionVars, + loopHelperParam.nextAnchorResultOrignalResultMap, + originalResultForResultMap, indiceLoopMap); // 3. move postOp to current body movePostOpToCurrentAnchor( b, parallelIdx, groupIdx, nxtFor->getResults(), - nxtFor->getBlock(), opQueue, movedQueue, inductionVars, - currentLoopStateIdxMap, loopState, currentOriginalOperandMap, - currentOperandOriginalMap, nextAnchorResults, - forResultOrignalResultMap, indiceLoopMap); + nxtFor->getBlock(), opQueue, movedQueue, + loopHelperParam.inductionVars, currentLoopStateIdxMap, loopState, + currentOriginalOperandMap, currentOperandOriginalMap, + loopHelperParam.nextAnchorResults, + loopHelperParam.nextAnchorResultOrignalResultMap, indiceLoopMap); // 4. 
generate loop results - generateLoopResults(b, loc, parallelIdx, groupIdx, nextAnchorResults, - nextAnchorResultsIdxMap, nxtFor->getResults(), - movedQueue, forResultOrignalResultMap, loopState, - currentOperandOriginalMap, nextAnchorArgsIdxMap); - maybeYieldValue(b, loc, nextAnchorResults); + generateLoopResults( + b, loc, parallelIdx, groupIdx, loopHelperParam.nextAnchorResults, + loopHelperParam.nextAnchorResultsIdxMap, nxtFor->getResults(), + movedQueue, loopHelperParam.nextAnchorResultOrignalResultMap, + loopState, currentOperandOriginalMap, nextAnchorArgsIdxMap); + maybeYieldValue(b, loc, loopHelperParam.nextAnchorResults); } else if (parallelIdx == parallelAxis.size()) { DenseMap tmpOriginOperandLoopArgsMap = - originalOperandLoopArgsMap; + loopHelperParam.originalOperandLoopArgsMap; DenseMap tmpLoopArgsOriginalOperandMap = - loopArgsOriginalOperandMap; + loopHelperParam.loopArgsOriginalOperandMap; // get accumualte value Attribute initValueAttr; @@ -1668,18 +1594,21 @@ scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( SmallVector argsArray; argsArray.emplace_back(accVal); localAnchorArgsIdxMap[accVal] = 0; - size_t accLoopStateIdx = currentLoopStateIdxMap - [originalOperandLoopArgsMap[multiReductionAcc]]; + size_t accLoopStateIdx = + loopHelperParam.currentLoopStateIdxMap + [loopHelperParam + .originalOperandLoopArgsMap[multiReductionAcc]]; localLoopArgsOriginalOperandMap[accVal] = multiReductionAcc; localOriginalOperandLoopArgsMap[multiReductionAcc] = accVal; for (auto [idx, x] : llvm::enumerate(loopState)) { - if (idx == accLoopStateIdx) { + if (idx == accLoopStateIdx) continue; - } + argsArray.emplace_back(x); localAnchorArgsIdxMap[x] = argsArray.size() - 1; - Value originalValue = loopArgsOriginalOperandMap[initArgs[idx]]; + Value originalValue = loopHelperParam.loopArgsOriginalOperandMap + [loopHelperParam.loopIterArgs[idx]]; localOriginalOperandLoopArgsMap[originalValue] = x; localLoopArgsOriginalOperandMap[x] = originalValue; } @@ -1687,14 +1616,16 @@ scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( auto nxtFor = reductionAxisGenerateForLoop( b, groupIdx, 0, parallelIdx, localAnchorArgsIdxMap, argsArray, localOriginalOperandLoopArgsMap, localLoopArgsOriginalOperandMap, - nextAnchorResults, nextAnchorResultsIdxMap, inductionVars, - forResultOrignalResultMap, originalResultForResultMap, - indiceLoopMap); + loopHelperParam.nextAnchorResults, + loopHelperParam.nextAnchorResultsIdxMap, + loopHelperParam.inductionVars, + loopHelperParam.nextAnchorResultOrignalResultMap, + originalResultForResultMap, indiceLoopMap); // insert accumulate value to original vector Value nxtForAccVal = originalResultForResultMap[multiReductionOp->getResults()[0]]; - size_t accIdx = nextAnchorResultsIdxMap[nxtForAccVal]; + size_t accIdx = loopHelperParam.nextAnchorResultsIdxMap[nxtForAccVal]; auto accRes = nxtFor->getResults()[accIdx]; Operation *reductionOp = b.create( @@ -1712,30 +1643,29 @@ scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( currentResultMap[insertOp->getResults()[0]] = multiReductionOp->getResults()[0]; currentResultIdxMap[insertOp->getResults()[0]] = accLoopStateIdx; - for (auto [idx, x] : llvm::enumerate(nextAnchorResults)) { - if (forResultOrignalResultMap[x] == - multiReductionOp->getResults()[0]) { + for (auto [idx, x] : + llvm::enumerate(loopHelperParam.nextAnchorResults)) { + if (loopHelperParam.nextAnchorResultOrignalResultMap[x] == + multiReductionOp->getResults()[0]) continue; - } - Value originalResult = forResultOrignalResultMap[x]; - size_t itrIdx 
= currentLoopStateIdxMap - [tmpOriginOperandLoopArgsMap[originalResult]]; + + Value originalResult = + loopHelperParam.nextAnchorResultOrignalResultMap[x]; + size_t itrIdx = loopHelperParam.currentLoopStateIdxMap + [tmpOriginOperandLoopArgsMap[originalResult]]; currentAnchorResults[itrIdx] = nxtFor->getResults()[idx]; currentResultIdxMap[nxtFor->getResults()[idx]] = itrIdx; currentResultMap[nxtFor->getResults()[idx]] = originalResult; } - nextAnchorResults.clear(); - forResultOrignalResultMap.clear(); - nextAnchorResultsIdxMap.clear(); - nextAnchorResults = std::move(currentAnchorResults); - forResultOrignalResultMap = std::move(currentResultMap); - nextAnchorResultsIdxMap = std::move(currentResultIdxMap); - // std::cout << "next anchor results : " << nextAnchorResults.size() - // << std::endl; - // for (auto x : nextAnchorResults) { - // x.getDefiningOp()->dump(); - // } - maybeYieldValue(b, loc, nextAnchorResults); + loopHelperParam.nextAnchorResults.clear(); + loopHelperParam.nextAnchorResultOrignalResultMap.clear(); + loopHelperParam.nextAnchorResultsIdxMap.clear(); + loopHelperParam.nextAnchorResults = std::move(currentAnchorResults); + loopHelperParam.nextAnchorResultOrignalResultMap = + std::move(currentResultMap); + loopHelperParam.nextAnchorResultsIdxMap = + std::move(currentResultIdxMap); + maybeYieldValue(b, loc, loopHelperParam.nextAnchorResults); } }); } @@ -1905,17 +1835,11 @@ void ForLoopGenerator::replaceOpUsersWithForLoopResult( scf::ForOp ForLoopGenerator::generateMultiReductionForLoop(const size_t grpIdx) { - MultiReductionCanonicalizer &rdCanonicalizer = - getMultiRdCanonicalizers()[grpIdx]; - vector::MultiDimReductionOp multiReductionOp = - rdCanonicalizer.getCandidateOps()[0]; DenseMap> indiceLoopMap; rearrageMultiReductionIR(grpIdx, indiceLoopMap); - // get current loop init args - DenseMap currentLoopStateIdxMap; - DenseMap nextAnchorResultsIdxMap; + DenseMap currentLoopStateIdxMap, nextAnchorResultsIdxMap; // map original operation operand with loop args DenseMap originalOperandLoopArgsMap, loopArgsOriginalOperandMap, forResultOrignalResultMap; @@ -1923,20 +1847,26 @@ ForLoopGenerator::generateMultiReductionForLoop(const size_t grpIdx) { SmallVector initArgs; prepareForLoopArgs(grpIdx, currentLoopStateIdxMap, originalOperandLoopArgsMap, loopArgsOriginalOperandMap, initArgs); + GenerateLoopHelper loopHelper; + loopHelper.loopIterArgs = initArgs; + loopHelper.originalOperandLoopArgsMap = originalOperandLoopArgsMap; + loopHelper.loopArgsOriginalOperandMap = loopArgsOriginalOperandMap; + loopHelper.currentLoopStateIdxMap = currentLoopStateIdxMap; + + MultiReductionCanonicalizer &rdCanonicalizer = + getMultiRdCanonicalizers()[grpIdx]; - SmallVector inductionVars; OpBuilder opBuilder(rdCanonicalizer.getCandidateOps()[0]); - SmallVector nextAnchorResults; - scf::ForOp forOp = parallelAxisGenerateForLoop( - opBuilder, grpIdx, 0, currentLoopStateIdxMap, initArgs, nextAnchorResults, - nextAnchorResultsIdxMap, inductionVars, originalOperandLoopArgsMap, - loopArgsOriginalOperandMap, forResultOrignalResultMap, indiceLoopMap); - replaceOpUsersWithForLoopResult(forOp, grpIdx, nextAnchorResults, - nextAnchorResultsIdxMap, - forResultOrignalResultMap); + scf::ForOp forOp = parallelAxisGenerateForLoop(opBuilder, grpIdx, 0, + indiceLoopMap, loopHelper); + replaceOpUsersWithForLoopResult(forOp, grpIdx, loopHelper.nextAnchorResults, + loopHelper.nextAnchorResultsIdxMap, + loopHelper.nextAnchorResultOrignalResultMap); IRRewriter rewriter(func); + vector::MultiDimReductionOp 
multiReductionOp = + rdCanonicalizer.getCandidateOps()[0]; rewriter.eraseOp(multiReductionOp); return forOp; @@ -2746,13 +2676,13 @@ void ForLoopGenerator::setOperationCorrectOperand( const DenseMap &opPermuationMap, DenseMap> &indiceloopMap) { for (auto [idx, opd] : llvm::enumerate(op->getOperands())) { - if (!originalOperandLoopArgsMap.contains(opd)) { + if (not originalOperandLoopArgsMap.contains(opd)) continue; - } + Value loopArg = originalOperandLoopArgsMap[opd]; - if (!operandIdxMap.contains(loopArg)) { + if (not operandIdxMap.contains(loopArg)) continue; - } + op->setOperand(idx, iterArgs[operandIdxMap.at(loopArg)]); } int offset = isa(op) ? 2 : 1; @@ -2764,33 +2694,31 @@ void ForLoopGenerator::setOperationCorrectOperand( auto dimExpr = permutationMap.getResults(); for (auto [idx, x] : llvm::enumerate(dimExpr)) { - if (!isa(x)) { + if (not isa(x)) llvm::llvm_unreachable_internal( "Permuatation map must contains dim expr."); - } size_t dim; - if (auto d = dyn_cast(x)) { + if (auto d = dyn_cast(x)) dim = d.getPosition(); - } - if (auto d = dyn_cast(x)) { + + if (auto d = dyn_cast(x)) dim = d.getValue(); - } + ShapedType tensorType = cast(op->getOperandTypes()[offset - 1]); size_t varIdx = dim; if (tensorType.getRank() > (int64_t)inductionVars.size()) { int64_t tensorOffset = tensorType.getRank() - inductionVars.size(); - if (dim < tensorOffset) { + if (dim < tensorOffset) continue; - } + varIdx = dim - tensorOffset; } - if (indiceloopMap.contains(op)) { + if (indiceloopMap.contains(op)) op->setOperand(dim + offset, inductionVars[indiceloopMap[op][varIdx]]); - } else { + else op->setOperand(dim + offset, inductionVars[varIdx]); - } } if (auto readOp = dyn_cast(op)) { size_t grpIdx = getFusionStrategy().getOpGroupIndexMap()[op]; @@ -3438,7 +3366,7 @@ void ForLoopGenerator::rewriteOperationAsVectorize( Operation *srcOp = transferWriteOp->getOperand(0).getDefiningOp(); - if (mlir::isa(srcOp)) { + if (isa(srcOp)) { createNewConstantOp(srcOp, &transferWriteOp, groupSteps); } else { opPermuationMap.insert( diff --git a/lib/gc/Transforms/TilingVector.h b/lib/gc/Transforms/TilingVector.h index 5fcc6b34e..5a723badc 100644 --- a/lib/gc/Transforms/TilingVector.h +++ b/lib/gc/Transforms/TilingVector.h @@ -410,8 +410,7 @@ class CanonicalizerCommonUsedData : public TypeHelper { return opPermuationMap; } - llvm::SmallVector & - getMultiRdCanonicalizers() { + SmallVector &getMultiRdCanonicalizers() { return multiRdCanonicalizers; } @@ -457,6 +456,19 @@ class CanonicalizerCommonUsedData : public TypeHelper { Operation *getNextTargetOperationInCurrentGroup(Operation *curOp, const size_t grpIdx); }; + +struct GenerateLoopHelper { + DenseMap currentLoopStateIdxMap; + ValueRange loopIterArgs; + SmallVector nextAnchorResults; + DenseMap nextAnchorResultsIdxMap; + DenseMap nextAnchorResultOrignalResultMap; + SmallVector inductionVars; + DenseMap originalOperandLoopArgsMap; + DenseMap loopArgsOriginalOperandMap; + GenerateLoopHelper() = default; +}; + /// generate for loop for each operation. 
class ForLoopGenerator : virtual public CanonicalizerCommonUsedData { private: @@ -467,6 +479,7 @@ class ForLoopGenerator : virtual public CanonicalizerCommonUsedData { ForLoopGenerator(func::FuncOp &func) : func(func) {} virtual ~ForLoopGenerator() {} + void setGeneratorFunc(func::FuncOp &func) { this->func = func; } void clearCurrentOperationGroup(size_t grpIdx); void generateGroupOpVectorizedIR(const int idx); @@ -516,8 +529,7 @@ class ForLoopGenerator : virtual public CanonicalizerCommonUsedData { void moveOperationsToCurrentForBody( const size_t groupIdx, const OpBuilder &b, ArrayRef inductionVars, - const llvm::DenseMap &operandIdxMap, - const ValueRange &loopState, + const llvm::DenseMap &operandIdxMap, ValueRange loopState, DenseMap &originalOperandLoopArgsMap, std::queue &queue, DenseMap> &indiceLoopMap); @@ -613,15 +625,8 @@ class ForLoopGenerator : virtual public CanonicalizerCommonUsedData { scf::ForOp parallelAxisGenerateForLoop( OpBuilder &opBuilder, const int groupIdx, const size_t parallelIdx, - llvm::DenseMap ¤tLoopStateIdxMap, - const ValueRange &initArgs, - llvm::SmallVector &nextAnchorResults, - llvm::DenseMap &nextAnchorResultsIdxMap, - llvm::SmallVector &inductionVars, - DenseMap &originalOperandLoopArgsMap, - DenseMap &loopArgsOriginalOperandMap, - DenseMap &forResultOrignalResultMap, - DenseMap> &indiceLoopMap); + DenseMap> &indiceLoopMap, + GenerateLoopHelper &loopHelperParam); vector::TransferReadOp cloneReductionTransferRead( Value &source, OpBuilder &b, IRMapping &readMap, diff --git a/test/mlir/test/gc/Transforms/cpu-vetor-distribution.mlir b/test/mlir/test/gc/Transforms/cpu-vetor-distribution.mlir index 4f338951c..eed9c859e 100644 --- a/test/mlir/test/gc/Transforms/cpu-vetor-distribution.mlir +++ b/test/mlir/test/gc/Transforms/cpu-vetor-distribution.mlir @@ -1,4 +1,4 @@ -// RUN: gc-opt %s --split-input-file --fold-tensor-operation --lower-to-tile-vector --CPU-physical-register-pass --mlir-print-ir-after-all | FileCheck %s +// RUN: gc-opt %s --split-input-file --fold-tensor-operation --lower-to-tile-vector --CPU-physical-register-pass | FileCheck %s // CHECK-DAG: #[[map0:.*]] = affine_map<()[s0, s1] -> (s0 * 64 + s1)> From 3db8b185da1facb8c131647459997ed1421ed95e Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Thu, 5 Sep 2024 14:43:42 +0800 Subject: [PATCH 37/66] add some comments --- lib/gc/Transforms/CPUPhysicalRegisterPass.cpp | 208 ++++++++---------- lib/gc/Transforms/TilingVector.h | 148 ++++++++----- .../gc/Transforms/cpu-vetor-distribution.mlir | 21 ++ 3 files changed, 209 insertions(+), 168 deletions(-) diff --git a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp index b42e22f47..9c0d3cdad 100644 --- a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp +++ b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp @@ -37,7 +37,7 @@ namespace { /// TODO: remove it in the future bool disableSpecialOp = false; bool disableBroadcastOp = false; -bool enableDebugPrinter = false; +bool enableDebugPrinter = true; void printQueue(const std::queue &opQueue) { auto tempQ(opQueue); @@ -641,11 +641,11 @@ void setOpVectorizationPermutationMap(Operation *op, OpBuilder &rewriter, auto destAffineMap = AffineMap::get(tensorType.getRank(), 0, affineExprs, rewriter.getContext()); SmallVector inBounds(1, true); - if (mlir::isa(op)) { + if (isa(op)) { auto transferWriteOp = cast(op); transferWriteOp.setPermutationMap(destAffineMap); transferWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds)); - } else if (mlir::isa(op)) { + } else 
if (isa(op)) { auto transferReadOp = cast(op); transferReadOp.setPermutationMap(destAffineMap); transferReadOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds)); @@ -828,7 +828,7 @@ void getOpSourceOps(Operation *op, DenseSet &srcOps) { while (!srcOperandsQueue.empty()) { Value accOperand = srcOperandsQueue.front(); srcOperandsQueue.pop_front(); - auto accOperandOp = accOperand.getDefiningOp(); + Operation *accOperandOp = accOperand.getDefiningOp(); if (!accOperandOp or visited.count(accOperandOp)) continue; @@ -930,7 +930,7 @@ vector::TransferReadOp ForLoopGenerator::cloneReductionTransferRead( SmallVector &inductionVars, bool lastDimReduction, MultiReduceOpAxisKind rdKind) { IRRewriter rewriter(b); - auto readOp = mlir::dyn_cast(source.getDefiningOp()); + auto readOp = dyn_cast(source.getDefiningOp()); assert(readOp && " Not transfer_read operation. Current multireduction " "operation may have wrong analysis IR."); @@ -1832,7 +1832,6 @@ void ForLoopGenerator::replaceOpUsersWithForLoopResult( rectifyGroupOperands(grpIdx, originalResult, forResult); } } - scf::ForOp ForLoopGenerator::generateMultiReductionForLoop(const size_t grpIdx) { @@ -2343,19 +2342,19 @@ bool MultiReductionCanonicalizer::hasLastDimReduction() { llvm::SmallDenseSet reductionAxisSet(reductionAxis.begin(), reductionAxis.end()); bool res = false; - if (reductionAxisSet.contains(typeRank - 1)) { + if (reductionAxisSet.contains(typeRank - 1)) res = true; - } + haslastDimReduction = res; return res; } void MultiReductionCanonicalizer::prepareSpecialOperationInfo() { - if (getCandidateOps().empty()) { + if (getCandidateOps().empty()) return; - } + sourceType = getCandidateOps()[0].getSourceVectorType(); - accType = mlir::dyn_cast(getCandidateOps()[0].getAcc().getType()); + accType = dyn_cast(getCandidateOps()[0].getAcc().getType()); getTypeRank(); getReductionAxisAndParallelAxis(); hasLastDimReduction(); @@ -2369,11 +2368,7 @@ void MultiReductionCanonicalizer::prepareSpecialOperationInfo() { } }; -void TransposeCanonicalizer::prepareSpecialOperationInfo() { - if (getCandidateOps().empty()) { - return; - } -} +void TransposeCanonicalizer::prepareSpecialOperationInfo() {} bool TransposeCanonicalizer::isTransposeOnAllOneDim() { vector::TransposeOp tpOp = getCandidateOps()[0]; @@ -2385,9 +2380,9 @@ bool TransposeCanonicalizer::isTransposeOnAllOneDim() { itrIdx++; continue; } - if (tpVectorType.getShape()[itrIdx] != 1) { + if (tpVectorType.getShape()[itrIdx] != 1) return false; - } + itrIdx++; } return true; @@ -2401,11 +2396,12 @@ bool TransposeCanonicalizer::isTwoDTranspose() { // get the first transpose axis size_t itrIdx = 0; while (itrIdx < rank) { - if ((int64_t)itrIdx != permutation[itrIdx]) { + if ((int64_t)itrIdx != permutation[itrIdx]) diffCount += 1; - } + itrIdx += 1; } + itrIdx = 0; while (itrIdx < rank) { if (permutation[itrIdx] != (int64_t)itrIdx) { @@ -2414,6 +2410,7 @@ bool TransposeCanonicalizer::isTwoDTranspose() { } itrIdx++; } + itrIdx = 0; // get the second transpose axis while (itrIdx < rank) { @@ -2423,13 +2420,14 @@ bool TransposeCanonicalizer::isTwoDTranspose() { } itrIdx++; } - const int tpStep = 16; + + const int tpStep = TRANSPOSE_KERNEL::KERNEL_16X16; VectorType vtType = getCandidateOps()[0].getResultVectorType(); // currently we only support shape that is an integer multiple of tpStep if (vtType.getShape()[getFirstTpIdx()] % tpStep != 0 or - vtType.getShape()[getSecondTpIdx()] % tpStep != 0) { + vtType.getShape()[getSecondTpIdx()] % tpStep != 0) return false; - } + return diffCount == 2; } @@ 
-2445,9 +2443,9 @@ bool ShapeCastCanonicalizer::isReadWriteOnLastDim() { // Map the index of the larger rank shape to the index of the smaller rank // shape. DenseMap> shapeIdxMap; - for (size_t i = 0; i < smallRankType.getRank(); i++) { + for (size_t i = 0; i < smallRankType.getRank(); i++) shapeIdxMap[i] = std::move(SmallVector()); - } + size_t itrIdx = 0; while (itrIdx < smallRankType.getRank()) { size_t endShape = getFirstTrueIndex(visitedAxis), dimSize = 1; @@ -2456,16 +2454,16 @@ bool ShapeCastCanonicalizer::isReadWriteOnLastDim() { // skip non corresponding axis // e.g.: vector<32x16x1x32xbf16> -> vector<1x512x32xbf16> while (largeRankType.getShape()[endShape] > - smallRankType.getShape()[itrIdx]) { + smallRankType.getShape()[itrIdx]) endShape++; - } + while (endShape < largeRankType.getRank()) { visitedAxis[endShape] = true; shapeIdxMap[itrIdx].emplace_back(endShape); dimSize *= largeRankType.getShape()[endShape]; - if ((int64_t)dimSize == smallRankType.getShape()[itrIdx]) { + if ((int64_t)dimSize == smallRankType.getShape()[itrIdx]) break; - } + endShape++; } itrIdx++; @@ -2500,9 +2498,9 @@ void CanonicalizerVectorOperation::initSpeicalOperationCanonicalizers() { getFusionStrategy().getOpGroups(); for (auto &grp : opGroups) { dummyInitSpecialOperation(); - if (grp.empty()) { + if (grp.empty()) continue; - } + std::queue tempQ(grp); while (!tempQ.empty()) { auto op = tempQ.front(); @@ -2531,6 +2529,17 @@ void CanonicalizerVectorOperation::initSpeicalOperationCanonicalizers() { } } +template +void CanonicalizerVectorOperation::processSpecialOperation( + T &canonicalizers, std::function generateFunc) { + for (auto [groupId, canonicalizer] : llvm::enumerate(canonicalizers)) { + SmallVector &ops = canonicalizer.getCandidateOps(); + if (!ops.empty()) + // generate MultiReduction for loops + generateFunc(groupId); + } +} + void CanonicalizerVectorOperation::canonicalizeSpecialOperation() { OpBuilder::InsertionGuard guard(rewriter); @@ -2538,35 +2547,25 @@ void CanonicalizerVectorOperation::canonicalizeSpecialOperation() { // traverse all groups llvm::SmallVector &multiRdCanonicalizers = getMultiRdCanonicalizers(); - llvm::SmallVector &transposeCanonicalizers = + processSpecialOperation, + vector::MultiDimReductionOp>( + multiRdCanonicalizers, [this](const size_t grpIdx) { + (void)generateMultiReductionForLoop(grpIdx); + }); + // generate loop for transpose operation + SmallVector &transposeCanonicalizers = getTransposeCanonicalizers(); - llvm::SmallVector &shapeCastCanonicalizers = + processSpecialOperation, + vector::TransposeOp>( + transposeCanonicalizers, + [this](const size_t grpIdx) { (void)generateTransposeForLoop(grpIdx); }); + // generate loop for shapecast opearation + SmallVector &shapeCastCanonicalizers = getShapeCastCanonicalizers(); - for (auto [groupId, rdCanonicalizer] : - llvm::enumerate(multiRdCanonicalizers)) { - SmallVector &rdOps = - rdCanonicalizer.getCandidateOps(); - if (!rdOps.empty()) { - // generate MultiReduction for loops - (void)generateMultiReductionForLoop(groupId); - } - } - for (auto [groupId, tpCanonicalizer] : - llvm::enumerate(transposeCanonicalizers)) { - SmallVector &transposeOps = - tpCanonicalizer.getCandidateOps(); - if (!transposeOps.empty()) { - (void)generateTransposeForLoop(groupId); - } - } - for (auto [groupId, scCanonicalizer] : - llvm::enumerate(shapeCastCanonicalizers)) { - SmallVector &scOps = - scCanonicalizer.getCandidateOps(); - if (!scOps.empty()) { - (void)generateShapeCastForLoop(groupId); - } - } + processSpecialOperation, + 
vector::ShapeCastOp>( + shapeCastCanonicalizers, + [this](const size_t grpIdx) { (void)generateShapeCastForLoop(grpIdx); }); } void CanonicalizerVectorOperation::run() { @@ -2669,7 +2668,7 @@ void CanonicalizerVectorOperation::run() { /// void ForLoopGenerator::setOperationCorrectOperand( - Operation *op, const ValueRange &iterArgs, + Operation *op, ValueRange iterArgs, const DenseMap &operandIdxMap, DenseMap &originalOperandLoopArgsMap, ArrayRef inductionVars, @@ -3155,26 +3154,28 @@ bool VectorFusionStrategy::isNeedNewGroup(Operation *op) { Operation *prevOp = nullptr; prevOp = getNotReadWriteOperaiton(tmpQ); if (!prevOp) { + if (opGroups.back().back()->getParentOp() != op->getParentOp()) { + // if previous operation is not in the same block, we need to create a + // group return true; } + if (isSpecialOp(op)) + return true; + return false; } - // not in the same operation - if (prevOp->getParentOp() != op->getParentOp()) { + if (prevOp->getParentOp() != op->getParentOp()) return true; - } // special operation need to check data dependency axis - if (hasDataDependency(prevOp, op)) { + if (hasDataDependency(prevOp, op)) return true; - } // previous operation vector type is not compatible with current operation - if (!isCompatibleVectorType(prevOp, op)) { + if (!isCompatibleVectorType(prevOp, op)) return true; - } } return false; } @@ -3198,9 +3199,9 @@ void VectorFusionStrategy::updateGroupBitgestVectorType(VectorType vectorType) { void VectorFusionStrategy::addOperationToGroup(Operation *op) { assert(op); VectorType vectorType = getOperationMaxVectorType(op).value(); - if (isNeedNewGroup(op)) { + if (isNeedNewGroup(op)) opGroups.emplace_back(std::queue()); - } + if (not isa(op)) { updateGroupBitgestVectorType(vectorType); while (not noNeedToJudgeOps.empty()) { @@ -3219,18 +3220,18 @@ void VectorFusionStrategy::addOperationToGroup(Operation *op) { // of in the same group have no data dependencies. Those operations can generate // a same outter for loop. 
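// As a hedged example (hand-written IR, not from the test suite): a chain like
//   %r = vector.transfer_read %src[...] : tensor<16x64xf32>, vector<16xf32>
//   %a = arith.addf %r, %r : vector<16xf32>
//   %w = vector.transfer_write %a, %dst[...] : vector<16xf32>, tensor<16x64xf32>
// stays in one group and later shares a single scf.for nest, while an
// operation that is special (isSpecialOp), hits a data-dependency conflict
// (hasDataDependency), lives under a different parent op, or carries an
// incompatible vector type opens a new group.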
void VectorFusionStrategy::classifyOperations() { - if (opGroups.empty()) { - // dummpy + // dummpy + if (opGroups.empty()) opGroups.emplace_back(std::queue()); - } + func->walk([&](Operation *op) { if (filterOperation(op)) { addOperationToGroup(op); return WalkResult::advance(); } - if (isNotNeedToProcessOp(op) and !opGroups.back().empty()) { + if (isNotNeedToProcessOp(op) and !opGroups.back().empty()) opGroups.emplace_back(std::queue()); - } + return WalkResult::advance(); }); } @@ -3249,15 +3250,14 @@ Value setOutGroupOperationOperandResult(Operation *op, if (isa(value)) { auto valueType = mlir::dyn_cast(value); if (valueType.isSplat()) { - if (mlir::isa(valueType.getElementType())) { + if (mlir::isa(valueType.getElementType())) initValueAttr = FloatAttr::get( resultElementType, valueType.getSplatValue().convertToDouble()); - } else { + else initValueAttr = IntegerAttr::get( resultElementType, valueType.getSplatValue().getSExtValue()); - } } else { // write original vector into tensor // then we transfer_read from the tensor @@ -3283,7 +3283,7 @@ Value setOutGroupOperationOperandResult(Operation *op, void setOperationOperandResult(Operation *op, const VectorType &newOperandType, const DenseMap &opMap) { for (auto [idx, x] : llvm::enumerate(op->getOperands())) { - if (mlir::dyn_cast(x.getType())) { + if (dyn_cast(x.getType())) { if (!opMap.contains(x.getDefiningOp())) { auto result = setOutGroupOperationOperandResult(x.getDefiningOp(), newOperandType); @@ -3293,11 +3293,9 @@ void setOperationOperandResult(Operation *op, const VectorType &newOperandType, } } } - for (auto x : op->getResults()) { - if (mlir::dyn_cast(x.getType())) { + for (auto x : op->getResults()) + if (dyn_cast(x.getType())) x.setType(newOperandType); - } - } }; /// Reimplementation of writing a tensor from a constant of denseElementattr. @@ -3680,10 +3678,9 @@ void VectorOperationAnalyzer::updateReturnResultKind(Operation *sourceOp, std::make_pair(ReturnTypeKind::RT_Both, srcOpAnchor); return; } - if (rtKind == ReturnTypeKind::RT_InGroup) { + if (rtKind == ReturnTypeKind::RT_InGroup) groupOpResults[sourceOpGid][sourceResult] = std::make_pair(rtKind, srcOpAnchor); - } } void VectorOperationAnalyzer::replaceConstantOpAsNewOp(Operation *op, @@ -3723,9 +3720,9 @@ void VectorOperationAnalyzer::replaceConstantOpAsNewOp(Operation *op, if (valueType.isSplat()) { FailureOr res = createArithSplatConstantOp( rewriter, constantOp->getLoc(), valueType, newOperandType); - if (failed(res)) { + if (failed(res)) llvm::llvm_unreachable_internal("Wrong to create constant op."); - } + op->setOperand(operandIdx, res.value()); // transfer read operation just use the constant value to do // calculation, don't need to read. 
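      // Illustrative note (assumption, not part of this change): for a splat
      // source such as
      //   %cst = arith.constant dense<0.000000e+00> : vector<16x32xf32>
      // the branch above is expected to rebuild it at the group's
      // single-register width (cf. getVectorzedType), e.g.
      //   %cst0 = arith.constant dense<0.000000e+00> : vector<16xf32>
      // and wire %cst0 straight into the consumer, so no transfer_read is
      // materialized for the constant.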
@@ -3837,15 +3834,14 @@ void VectorOperationAnalyzer::analysisGroupOperaion() { bool inSameGroupNeedReturn = !notInSameGroup and OpAnchorPos[sourceOp] > OpAnchorPos[op]; - if (notInSameGroup or outOfGroup or inSameGroupNeedReturn) { + if (notInSameGroup or outOfGroup or inSameGroupNeedReturn) groupOperationNeedReturnResult(sourceOpGid, sourceOp, op, idx, inSameGroupNeedReturn); - } + continue; } - if (isa_and_nonnull(sourceOp)) { + if (isa_and_nonnull(sourceOp)) replaceConstantOpAsNewOp(op, sourceOp, idx); - } } }); analysisEmptyGroup(); @@ -3880,17 +3876,15 @@ mlir::FailureOr ForLoopGenerator::generateVectorizedForLoop( auto &resultSet = getGroupOpResults(); assert(!resultSet.empty() && "Expected non-empty value"); // prepare for loop iterargs - SmallVector operands, nextLoopResults; + SmallVector operands; + SmallVector nextLoopResults; DenseMap operandIdxMap, resultIdxMap; DenseMap originalOperandMap, operandOriginalMap, forResultOrignalResultMap; - SetVector &initArgs = getGroupOpInitArgs()[groupId]; - for (Value x : initArgs) { - operands.emplace_back(x); - operandIdxMap[x] = operands.size() - 1; - originalOperandMap[x] = x; - operandOriginalMap[x] = x; - } + + prepareForLoopArgs(groupId, operandIdxMap, originalOperandMap, + operandOriginalMap, operands); + ValueRange forIterArgs(operands); ArrayRef shapes = vectorType.getShape(); SmallVector inductionVars; @@ -3900,21 +3894,8 @@ mlir::FailureOr ForLoopGenerator::generateVectorizedForLoop( 0, groupId, rewriter, rewriter.getUnknownLoc(), forIterArgs, shapes, inductionVars, operandIdxMap, originalOperandMap, operandOriginalMap, nextLoopResults, resultIdxMap, forResultOrignalResultMap, indiceLoopMap); - DenseSet forOpChildOps; - forOp->walk([&](Operation *op) { forOpChildOps.insert(op); }); - auto replaceIfFn = [&](OpOperand &use) { - return not forOpChildOps.contains(use.getOwner()); - }; - - for (auto x : nextLoopResults) { - auto originalResult = forResultOrignalResultMap[x]; - rewriter.replaceOpUsesWithIf(originalResult.getDefiningOp(), - forOp->getResults()[resultIdxMap[x]], - replaceIfFn); - // following group must use the replaced result as operand - rectifyGroupOperands(groupId, originalResult, - forOp->getResults()[resultIdxMap[x]]); - } + replaceOpUsersWithForLoopResult(forOp, groupId, nextLoopResults, resultIdxMap, + forResultOrignalResultMap); return forOp; } @@ -4094,9 +4075,8 @@ void moveSomeInterferenceOperation( // get the position of each operation func->walk([&](Operation *op) { operationPosition[op] = opCounter++; - if (conditionalFunc(op)) { + if (conditionalFunc(op)) candidateOps.emplace_back(op); - } }); moveCandidateOperation(operationPosition, candidateOps); // eliminate some useless operation diff --git a/lib/gc/Transforms/TilingVector.h b/lib/gc/Transforms/TilingVector.h index 5a723badc..d695ae63c 100644 --- a/lib/gc/Transforms/TilingVector.h +++ b/lib/gc/Transforms/TilingVector.h @@ -74,6 +74,33 @@ struct HardWareInfo { bool favx2 = true; }; +//===----------------------------------------------------------------------===// +// helper function +//===----------------------------------------------------------------------===// +/// Using to avoid too many parameters in function +struct GenerateLoopHelper { + /// loop iteration args index map + DenseMap currentLoopStateIdxMap; + /// loop iteration args + ValueRange loopIterArgs; + /// next loop anchor yield results + SmallVector nextAnchorResults; + /// next loop anchor yield results index map + DenseMap nextAnchorResultsIdxMap; + /// next loop anchor yield results 
original result map + DenseMap nextAnchorResultOrignalResultMap; + /// loop induction variables + SmallVector inductionVars; + /// original operand with loop args map + DenseMap originalOperandLoopArgsMap; + /// loop args with original operand map + DenseMap loopArgsOriginalOperandMap; + /// record operation's correct loop indice, due to some operation like reduce + /// may need to reorder loop indice + DenseMap> indiceLoopMap; + GenerateLoopHelper() = default; +}; + /// Vector type conversion helper class class TypeHelper { private: @@ -143,18 +170,25 @@ class VectorFusionStrategy : public TypeHelper { VectorFusionStrategy &operator=(VectorFusionStrategy &&) = default; /// Get the map which contains each group vector type which has biggest rank. - llvm::SmallDenseMap &getGroupBiggestRankVectorType() { + llvm::SmallDenseMap & + getGroupBiggestRankVectorType() noexcept { return groupBigestRankVectorType; }; /// Get the operation group obtained by fusion strategy analysis - SmallVector, 8> &getOpGroups() { return opGroups; } + SmallVector, 8> &getOpGroups() noexcept { + return opGroups; + } /// Get the operation belong to which group index map - DenseMap &getOpGroupIndexMap() { + DenseMap &getOpGroupIndexMap() noexcept { return opGroupIndexMap; } /// Get the map contains max steps of each group - llvm::SmallVector &getGroupMaxSteps() { return groupMaxSteps; } - llvm::DenseMap &getOpAnchorPos() { return opAnchorPos; } + llvm::SmallVector &getGroupMaxSteps() noexcept { + return groupMaxSteps; + } + llvm::DenseMap &getOpAnchorPos() noexcept { + return opAnchorPos; + } func::FuncOp &getFunc() { return func; } /// Do fusion strategy @@ -215,14 +249,14 @@ enum class MultiReduceOpAxisKind { Reduction, Parallel }; class MultiReductionCanonicalizer : public SpecialOperationCanonicalizer { private: - llvm::SmallVector reductionAxis, parallelAxis; + SmallVector reductionAxis, parallelAxis; std::queue prevOps, postOps, accRelatedOps, sourceRelatedOps; bool haslastDimReduction = false; bool isStandaloneOp = false; /// empty reduction means that all the reduction axis is 1 bool isEmptyReduction = true; int64_t typeRank = -1; - llvm::SetVector originalOpResults; + SetVector originalOpResults; VectorType sourceType, accType; llvm::SmallDenseMap resultIdxMap; @@ -234,25 +268,33 @@ class MultiReductionCanonicalizer isStandaloneOp = candidateRdOps.size() == 1; prepareSpecialOperationInfo(); }; - virtual ~MultiReductionCanonicalizer() {}; + virtual ~MultiReductionCanonicalizer() noexcept {}; int64_t getTypeRank(); void getReductionAxisAndParallelAxis(); bool hasLastDimReduction(); - bool getIsStandAloneOp() { return isStandaloneOp; } - bool getHasLastDimReduction() { return haslastDimReduction; } - bool getIsEmptyReduction() { return isEmptyReduction; } + bool getIsStandAloneOp() noexcept { return isStandaloneOp; } + bool getHasLastDimReduction() noexcept { return haslastDimReduction; } + bool getIsEmptyReduction() noexcept { return isEmptyReduction; } void initReductionAxis(); void initParallelAxis(); - SmallVector &getReductionAxis() { return reductionAxis; }; - SmallVector &getParallelAxis() { return parallelAxis; }; - std::queue &getPrevOps() { return prevOps; } - std::queue &getPostOps() { return postOps; } - std::queue &getAccRelatedOps() { return accRelatedOps; } - std::queue &getSourceRelatedOps() { return sourceRelatedOps; } - SetVector &getOriginalOpResults() { return originalOpResults; } - VectorType getSourceType() { return sourceType; }; - VectorType getAccType() { return accType; }; - 
llvm::SmallDenseMap &getResultIdxMap() { return resultIdxMap; } + SmallVector &getReductionAxis() noexcept { + return reductionAxis; + }; + SmallVector &getParallelAxis() noexcept { return parallelAxis; }; + std::queue &getPrevOps() noexcept { return prevOps; } + std::queue &getPostOps() noexcept { return postOps; } + std::queue &getAccRelatedOps() noexcept { return accRelatedOps; } + std::queue &getSourceRelatedOps() noexcept { + return sourceRelatedOps; + } + SetVector &getOriginalOpResults() noexcept { + return originalOpResults; + } + VectorType getSourceType() noexcept { return sourceType; }; + VectorType getAccType() noexcept { return accType; }; + llvm::SmallDenseMap &getResultIdxMap() noexcept { + return resultIdxMap; + } void setResultIdxMap(const llvm::SmallDenseMap &map) { resultIdxMap = map; } @@ -273,7 +315,7 @@ class BroadcastCanonicalizer const llvm::SmallVector &candidateBcOps) : SpecialOperationCanonicalizer( candidateBcOps, SpecialOperationKind::OP_Broadcast) {}; - virtual ~BroadcastCanonicalizer() {} + virtual ~BroadcastCanonicalizer() noexcept {} void prepareSpecialOperationInfo() override {} static bool classof(SpecialOperationCanonicalizer *canonicalizer) { return canonicalizer->getKind() == SpecialOperationKind::OP_Broadcast; @@ -290,7 +332,7 @@ class TransposeCanonicalizer const llvm::SmallVector &candidateTpOps) : SpecialOperationCanonicalizer( candidateTpOps, SpecialOperationKind::OP_Transpose) {}; - virtual ~TransposeCanonicalizer() {} + virtual ~TransposeCanonicalizer() noexcept {} void prepareSpecialOperationInfo() override; static bool classof(SpecialOperationCanonicalizer *canonicalizer) { return canonicalizer->getKind() == SpecialOperationKind::OP_Transpose; @@ -299,8 +341,8 @@ class TransposeCanonicalizer KERNEL_16X16 = 16, }; - size_t getFirstTpIdx() { return firstTpIdx; } - size_t getSecondTpIdx() { return secondTpIdx; } + size_t getFirstTpIdx() noexcept { return firstTpIdx; } + size_t getSecondTpIdx() noexcept { return secondTpIdx; } bool isTwoDTranspose(); bool isTransposeOnAllOneDim(); }; @@ -310,7 +352,7 @@ class ShapeCastCanonicalizer private: public: ShapeCastCanonicalizer( - const llvm::SmallVector &candidateScOps) + const SmallVector &candidateScOps) : SpecialOperationCanonicalizer( candidateScOps, SpecialOperationKind::OP_ShapeCast) {}; virtual ~ShapeCastCanonicalizer() {} @@ -361,7 +403,7 @@ class CanonicalizerCommonUsedData : public TypeHelper { llvm::DenseMap &opPermuationMap) : fusionStrategy(fusionStrategy), groupOpResults(groupOpResults), groupOpInitArgs(groupOpInitArgs), opPermuationMap(opPermuationMap) {} - virtual ~CanonicalizerCommonUsedData() {}; + virtual ~CanonicalizerCommonUsedData() noexcept {}; /// Set fusion strategy void setFuseStrategy(VectorFusionStrategy &&strategy) { @@ -380,49 +422,56 @@ class CanonicalizerCommonUsedData : public TypeHelper { } } } + void setGroupOpResults( const SmallVector< llvm::MapVector>, 8> &results) { groupOpResults = std::move(results); } + void setGroupOpIterArgs( const llvm::SmallVector, 8> &initArgs) { groupOpInitArgs = std::move(initArgs); } - void setPermutationMap(const llvm::DenseMap &map) { + + void setPermutationMap(const DenseMap &map) { opPermuationMap = std::move(map); } // get methods - VectorFusionStrategy &getFusionStrategy() { return fusionStrategy; } + VectorFusionStrategy &getFusionStrategy() noexcept { return fusionStrategy; } SmallVector>, 8> & - getGroupOpResults() { + getGroupOpResults() noexcept { return groupOpResults; } - SmallVector, 8> &getGroupOpInitArgs() { + 
SmallVector, 8> &getGroupOpInitArgs() noexcept { return groupOpInitArgs; } - DenseMap &getOpPermuationMap() { + DenseMap &getOpPermuationMap() noexcept { return opPermuationMap; } - SmallVector &getMultiRdCanonicalizers() { + SmallVector & + getMultiRdCanonicalizers() noexcept { return multiRdCanonicalizers; } - llvm::SmallVector &getBroadcastCanonicalizers() { + llvm::SmallVector & + getBroadcastCanonicalizers() noexcept { return broadcastCanonicalizers; } - llvm::SmallVector &getTransposeCanonicalizers() { + llvm::SmallVector & + getTransposeCanonicalizers() noexcept { return transposeCanonicalizers; } - llvm::SmallVector &getShapeCastCanonicalizers() { + llvm::SmallVector & + getShapeCastCanonicalizers() noexcept { return shapeCastCanonicalizers; } @@ -457,18 +506,6 @@ class CanonicalizerCommonUsedData : public TypeHelper { const size_t grpIdx); }; -struct GenerateLoopHelper { - DenseMap currentLoopStateIdxMap; - ValueRange loopIterArgs; - SmallVector nextAnchorResults; - DenseMap nextAnchorResultsIdxMap; - DenseMap nextAnchorResultOrignalResultMap; - SmallVector inductionVars; - DenseMap originalOperandLoopArgsMap; - DenseMap loopArgsOriginalOperandMap; - GenerateLoopHelper() = default; -}; - /// generate for loop for each operation. class ForLoopGenerator : virtual public CanonicalizerCommonUsedData { private: @@ -478,9 +515,9 @@ class ForLoopGenerator : virtual public CanonicalizerCommonUsedData { ForLoopGenerator() = default; ForLoopGenerator(func::FuncOp &func) : func(func) {} - virtual ~ForLoopGenerator() {} + virtual ~ForLoopGenerator() noexcept {} - void setGeneratorFunc(func::FuncOp &func) { this->func = func; } + void setGeneratorFunc(func::FuncOp &func) noexcept { this->func = func; } void clearCurrentOperationGroup(size_t grpIdx); void generateGroupOpVectorizedIR(const int idx); @@ -533,8 +570,9 @@ class ForLoopGenerator : virtual public CanonicalizerCommonUsedData { DenseMap &originalOperandLoopArgsMap, std::queue &queue, DenseMap> &indiceLoopMap); + void setOperationCorrectOperand( - Operation *op, const ValueRange &iterArgs, + Operation *op, ValueRange iterArgs, const DenseMap &operandIdxMap, DenseMap &originalOperandLoopArgsMap, ArrayRef inductionVars, @@ -717,7 +755,6 @@ class CanonicalizerVectorOperation : virtual public ForLoopGenerator, private: func::FuncOp func; IRRewriter rewriter; - CanonicalizerKind kind; public: @@ -739,8 +776,11 @@ class CanonicalizerVectorOperation : virtual public ForLoopGenerator, virtual ~CanonicalizerVectorOperation() = default; // get functions - func::FuncOp &getFunc() { return func; }; - IRRewriter &getIRWewriter() { return rewriter; } + func::FuncOp &getFunc() noexcept { return func; }; + IRRewriter &getIRWewriter() noexcept { return rewriter; } + template + void processSpecialOperation(T &canonicalizers, + std::function generateFunc); // void canonicalizeSpecialOperation(); void clearSpecialOperationCanonicalizers(); diff --git a/test/mlir/test/gc/Transforms/cpu-vetor-distribution.mlir b/test/mlir/test/gc/Transforms/cpu-vetor-distribution.mlir index eed9c859e..51e6ef721 100644 --- a/test/mlir/test/gc/Transforms/cpu-vetor-distribution.mlir +++ b/test/mlir/test/gc/Transforms/cpu-vetor-distribution.mlir @@ -596,3 +596,24 @@ func.func @reduce_fuse_test12(%input: tensor<16x32x64xf32>, %1 = linalg.mul ins(%reduce, %reduce : tensor<16x32xf32>, tensor<16x32xf32>) outs(%init: tensor<16x32xf32>) -> tensor<16x32xf32> func.return %1 : tensor<16x32xf32> } + +// func.func @pad_single_test13(%arg0: tensor<1x64x56x56xf32>) -> tensor<1x64x58x58xf32> 
{ +// %cst = arith.constant 0.000000e+00 : f32 +// %padded = tensor.pad %arg0 low[0, 0, 1, 1] high[0, 0, 1, 1] { +// ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index): +// tensor.yield %cst : f32 +// } : tensor<1x64x56x56xf32> to tensor<1x64x58x58xf32> +// return %padded : tensor<1x64x58x58xf32> +// } + + +// func.func @pad_valid_pack_propagation(%arg0: tensor<1x64x56x56xf32>) -> tensor<1x2x58x58x32xf32> { +// %cst = arith.constant 0.000000e+00 : f32 +// %padded = tensor.pad %arg0 low[0, 0, 1, 1] high[0, 0, 1, 1] { +// ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index): +// tensor.yield %cst : f32 +// } : tensor<1x64x56x56xf32> to tensor<1x64x58x58xf32> +// %0 = tensor.empty() : tensor<1x2x58x58x32xf32> +// %1 = tensor.pack %padded inner_dims_pos = [1] inner_tiles = [32] into %0 : tensor<1x64x58x58xf32> -> tensor<1x2x58x58x32xf32> +// return %1 : tensor<1x2x58x58x32xf32> +// } From 0c5e5a44a1b2bfc063b1957b2d75029702b8602c Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Thu, 5 Sep 2024 16:29:51 +0800 Subject: [PATCH 38/66] simplify nestedforloop generate --- lib/gc/Transforms/CPUPhysicalRegisterPass.cpp | 134 ++++++++---------- lib/gc/Transforms/TilingVector.h | 22 +-- 2 files changed, 63 insertions(+), 93 deletions(-) diff --git a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp index 9c0d3cdad..92e208070 100644 --- a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp +++ b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp @@ -1228,13 +1228,12 @@ void ForLoopGenerator::generateLoopResults( nextAnchorResults.clear(); nextAnchorResultsIdxMap.clear(); // reduction operation due to special process results size will be zero - if (results.size() > 0) { + if (results.size() > 0) for (Value x : loopState) { nextAnchorResults.emplace_back(results[nextOperandIdxMap[x]]); nextAnchorResultsIdxMap[results[nextOperandIdxMap[x]]] = nextAnchorResults.size() - 1; } - } forResultOrignalResultMap = std::move(currentResultMap); } @@ -1739,17 +1738,14 @@ scf::ForOp ForLoopGenerator::generateTransposeForLoopWithLastDim( }); } -void ForLoopGenerator::prepareForLoopArgs( - const size_t grpIdx, DenseMap ¤tLoopStateIdxMap, - DenseMap &originalOperandLoopArgsMap, - DenseMap &loopArgsOriginalOperandMap, - SmallVector &loopArgs) { +void ForLoopGenerator::prepareForLoopArgs(const size_t grpIdx, + GenerateLoopHelper &loopHelper) { SetVector &grpArgs = getGroupOpInitArgs()[grpIdx]; - loopArgs.assign(grpArgs.begin(), grpArgs.end()); + loopHelper.loopIterArgs = grpArgs.getArrayRef(); for (auto [idx, val] : llvm::enumerate(grpArgs)) { - currentLoopStateIdxMap[val] = idx; - originalOperandLoopArgsMap[val] = val; - loopArgsOriginalOperandMap[val] = val; + loopHelper.currentLoopStateIdxMap[val] = idx; + loopHelper.originalOperandLoopArgsMap[val] = val; + loopHelper.loopArgsOriginalOperandMap[val] = val; } } @@ -1839,18 +1835,8 @@ ForLoopGenerator::generateMultiReductionForLoop(const size_t grpIdx) { rearrageMultiReductionIR(grpIdx, indiceLoopMap); // get current loop init args DenseMap currentLoopStateIdxMap, nextAnchorResultsIdxMap; - // map original operation operand with loop args - DenseMap originalOperandLoopArgsMap, loopArgsOriginalOperandMap, - forResultOrignalResultMap; - - SmallVector initArgs; - prepareForLoopArgs(grpIdx, currentLoopStateIdxMap, originalOperandLoopArgsMap, - loopArgsOriginalOperandMap, initArgs); GenerateLoopHelper loopHelper; - loopHelper.loopIterArgs = initArgs; - loopHelper.originalOperandLoopArgsMap = originalOperandLoopArgsMap; - 
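To make the intent of this refactor easier to follow, here is a minimal usage sketch, not part of the patch: it assumes the ForLoopGenerator members declared in this series, restores the template arguments that the patch text elides, and mirrors the rewritten generateVectorizedForLoop. grpIdx, rewriter and vectorType stand for locals already in scope inside a ForLoopGenerator member.

// Sketch only (hypothetical call sequence, reusing declarations introduced in this patch).
GenerateLoopHelper loopHelper;            // one state object instead of several map/vector params
prepareForLoopArgs(grpIdx, loopHelper);   // seeds loopIterArgs and the operand <-> iter-arg maps
ArrayRef<int64_t> shapes = vectorType.getShape();
scf::ForOp forOp = constructNestedForOp(
    /*forDimIdx=*/0, grpIdx, rewriter, rewriter.getUnknownLoc(), shapes, loopHelper);
// Rewire users of the original group results to the results carried out of the loop nest.
replaceOpUsersWithForLoopResult(forOp, grpIdx, loopHelper.nextAnchorResults,
                                loopHelper.nextAnchorResultsIdxMap,
                                loopHelper.nextAnchorResultOrignalResultMap);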
loopHelper.loopArgsOriginalOperandMap = loopArgsOriginalOperandMap; - loopHelper.currentLoopStateIdxMap = currentLoopStateIdxMap; + prepareForLoopArgs(grpIdx, loopHelper); MultiReductionCanonicalizer &rdCanonicalizer = getMultiRdCanonicalizers()[grpIdx]; @@ -2224,8 +2210,12 @@ scf::ForOp ForLoopGenerator::generateTransposeForLoop(const size_t grpIdx) { DenseMap originalOperandMap, operandOriginalMap, resultIdxMap, forResultOrignalResultMap; SmallVector iterArgs; - prepareForLoopArgs(grpIdx, operandIdxMap, originalOperandMap, - operandOriginalMap, iterArgs); + GenerateLoopHelper loopHelper; + prepareForLoopArgs(grpIdx, loopHelper); + operandIdxMap = loopHelper.currentLoopStateIdxMap; + originalOperandMap = loopHelper.originalOperandLoopArgsMap; + operandOriginalMap = loopHelper.loopArgsOriginalOperandMap; + iterArgs = loopHelper.loopIterArgs; SmallVector inductionVars; // TODO: need to process transpose on all one dim @@ -2733,15 +2723,8 @@ void ForLoopGenerator::setOperationCorrectOperand( scf::ForOp ForLoopGenerator::constructNestedForOp( const size_t forDimIdx, const size_t groupIdx, OpBuilder &b, - const Location &loc, const ValueRange &iterArgs, - const ArrayRef &dims, SmallVector &inductionVars, - DenseMap &operandIdxMap, - DenseMap &originalOperandMap, - DenseMap &operandOriginalMap, - SmallVector &nextAnchorResults, - DenseMap &nextAnchorResultsIdxMap, - DenseMap &forResultOrignalResultMap, - DenseMap> &indiceLoopMap) { + const Location &loc, ArrayRef dims, + GenerateLoopHelper &loopHelper) { const int loop_step = getFusionStrategy().getGroupMaxSteps()[groupIdx]; // loop initialization variable auto zero = makeIndexArithConstantOp(b, loc, 0); @@ -2751,9 +2734,9 @@ scf::ForOp ForLoopGenerator::constructNestedForOp( // Create a loop and move vectorized operation into loops. auto forOp = b.create( - loc, zero, numIter, forSteps, iterArgs, + loc, zero, numIter, forSteps, loopHelper.loopIterArgs, [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) { - inductionVars.emplace_back(iv); + loopHelper.inductionVars.emplace_back(iv); // inner most body of the loop if (forDimIdx == dims.size() - 1) { @@ -2768,49 +2751,60 @@ scf::ForOp ForLoopGenerator::constructNestedForOp( // 3. move opeartions to current for block moveOperationsToCurrentForBody( - groupIdx, b, inductionVars, operandIdxMap, loopState, - originalOperandMap, movingOperation, indiceLoopMap); + groupIdx, b, loopHelper.inductionVars, + loopHelper.currentLoopStateIdxMap, loopState, + loopHelper.originalOperandLoopArgsMap, movingOperation, + loopHelper.indiceLoopMap); getResultInCurrentOps(forDimIdx, groupIdx, movingOperation, - nextAnchorResults, nextAnchorResultsIdxMap, - forResultOrignalResultMap); - maybeYieldValue(b, loc, nextAnchorResults); + loopHelper.nextAnchorResults, + loopHelper.nextAnchorResultsIdxMap, + loopHelper.nextAnchorResultOrignalResultMap); + maybeYieldValue(b, loc, loopHelper.nextAnchorResults); } else { // outter loop // 1. 
move pre-Op to current body DenseMap nextAnchorArgsIdxMap; SmallVector nextAnchorArgs; - DenseMap currentOriginalOperandMap = originalOperandMap; - DenseMap currentOperandOriginalMap = operandOriginalMap; + DenseMap currentOriginalOperandMap = + loopHelper.originalOperandLoopArgsMap; + DenseMap currentOperandOriginalMap = + loopHelper.loopArgsOriginalOperandMap; + DenseMap currentArgsIdxMap = + loopHelper.currentLoopStateIdxMap; std::queue movedQueue; std::queue &opQueue = getFusionStrategy().getOpGroups()[groupIdx]; movePreOpToCurrentAnchor( - forDimIdx, groupIdx, b, inductionVars, loopState, operandIdxMap, - nextAnchorArgsIdxMap, nextAnchorArgs, opQueue, movedQueue, - originalOperandMap, operandOriginalMap, indiceLoopMap); + forDimIdx, groupIdx, b, loopHelper.inductionVars, loopState, + loopHelper.currentLoopStateIdxMap, nextAnchorArgsIdxMap, + nextAnchorArgs, opQueue, movedQueue, + loopHelper.originalOperandLoopArgsMap, + loopHelper.loopArgsOriginalOperandMap, loopHelper.indiceLoopMap); - auto nxtFor = constructNestedForOp( - forDimIdx + 1, groupIdx, b, loc, nextAnchorArgs, dims, - inductionVars, nextAnchorArgsIdxMap, originalOperandMap, - operandOriginalMap, nextAnchorResults, nextAnchorResultsIdxMap, - forResultOrignalResultMap, indiceLoopMap); + loopHelper.loopIterArgs = nextAnchorArgs; + loopHelper.currentLoopStateIdxMap = nextAnchorArgsIdxMap; + auto nxtFor = constructNestedForOp(forDimIdx + 1, groupIdx, b, loc, + dims, loopHelper); movePostOpToCurrentAnchor( b, forDimIdx, groupIdx, nxtFor->getResults(), b.getBlock(), - opQueue, movedQueue, inductionVars, operandIdxMap, loopState, - currentOriginalOperandMap, currentOperandOriginalMap, - nextAnchorResults, forResultOrignalResultMap, indiceLoopMap); + opQueue, movedQueue, loopHelper.inductionVars, currentArgsIdxMap, + loopState, currentOriginalOperandMap, currentOperandOriginalMap, + loopHelper.nextAnchorResults, + loopHelper.nextAnchorResultOrignalResultMap, + loopHelper.indiceLoopMap); - generateLoopResults(b, loc, forDimIdx, groupIdx, nextAnchorResults, - nextAnchorResultsIdxMap, nxtFor->getResults(), - movedQueue, forResultOrignalResultMap, loopState, - currentOperandOriginalMap, nextAnchorArgsIdxMap); + generateLoopResults( + b, loc, forDimIdx, groupIdx, loopHelper.nextAnchorResults, + loopHelper.nextAnchorResultsIdxMap, nxtFor->getResults(), + movedQueue, loopHelper.nextAnchorResultOrignalResultMap, + loopState, currentOperandOriginalMap, nextAnchorArgsIdxMap); - maybeYieldValue(b, loc, nextAnchorResults); + maybeYieldValue(b, loc, loopHelper.nextAnchorResults); } }); return forOp; @@ -3873,29 +3867,17 @@ void ForLoopGenerator::rectifyGroupOperands(size_t currentGroupId, mlir::FailureOr ForLoopGenerator::generateVectorizedForLoop( const size_t groupId, IRRewriter &rewriter, VectorType vectorType) { - auto &resultSet = getGroupOpResults(); - assert(!resultSet.empty() && "Expected non-empty value"); // prepare for loop iterargs - SmallVector operands; - SmallVector nextLoopResults; - DenseMap operandIdxMap, resultIdxMap; - DenseMap originalOperandMap, operandOriginalMap, - forResultOrignalResultMap; - - prepareForLoopArgs(groupId, operandIdxMap, originalOperandMap, - operandOriginalMap, operands); + GenerateLoopHelper loopHelper; + prepareForLoopArgs(groupId, loopHelper); - ValueRange forIterArgs(operands); ArrayRef shapes = vectorType.getShape(); - SmallVector inductionVars; - DenseMap> indiceLoopMap; // generate for loop auto forOp = constructNestedForOp( - 0, groupId, rewriter, rewriter.getUnknownLoc(), forIterArgs, shapes, - 
inductionVars, operandIdxMap, originalOperandMap, operandOriginalMap, - nextLoopResults, resultIdxMap, forResultOrignalResultMap, indiceLoopMap); - replaceOpUsersWithForLoopResult(forOp, groupId, nextLoopResults, resultIdxMap, - forResultOrignalResultMap); + 0, groupId, rewriter, rewriter.getUnknownLoc(), shapes, loopHelper); + replaceOpUsersWithForLoopResult(forOp, groupId, loopHelper.nextAnchorResults, + loopHelper.nextAnchorResultsIdxMap, + loopHelper.nextAnchorResultOrignalResultMap); return forOp; } diff --git a/lib/gc/Transforms/TilingVector.h b/lib/gc/Transforms/TilingVector.h index d695ae63c..eee1d3a12 100644 --- a/lib/gc/Transforms/TilingVector.h +++ b/lib/gc/Transforms/TilingVector.h @@ -522,11 +522,7 @@ class ForLoopGenerator : virtual public CanonicalizerCommonUsedData { void generateGroupOpVectorizedIR(const int idx); /// prepare for loop iteration args - void prepareForLoopArgs(const size_t grpIdx, - DenseMap ¤tLoopStateIdxMap, - DenseMap &originalOperandLoopArgsMap, - DenseMap &loopArgsOriginalOperandMap, - SmallVector &loopArgs); + void prepareForLoopArgs(const size_t grpIdx, GenerateLoopHelper &loopHelper); /// replace original operation result with corresponding for loop result void replaceOpUsersWithForLoopResult( @@ -551,18 +547,10 @@ class ForLoopGenerator : virtual public CanonicalizerCommonUsedData { generateVectorizedForLoop(const size_t groupId, IRRewriter &rewriter, const VectorType vectorType); - scf::ForOp constructNestedForOp( - const size_t forDimIdx, const size_t groupIdx, OpBuilder &b, - const Location &loc, const ValueRange &iterArgs, - const llvm::ArrayRef &dims, - llvm::SmallVector &inductionVars, - llvm::DenseMap &operandIdxMap, - DenseMap &originalOperandMap, - DenseMap &operandOriginalMap, - llvm::SmallVector &nextAnchorResults, - DenseMap &nextAnchorResultsIdxMap, - DenseMap &forResultOrignalResultMap, - DenseMap> &indiceLoopMap); + scf::ForOp constructNestedForOp(const size_t forDimIdx, const size_t groupIdx, + OpBuilder &b, const Location &loc, + ArrayRef dims, + GenerateLoopHelper &loopGenerator); void moveOperationsToCurrentForBody( const size_t groupIdx, const OpBuilder &b, ArrayRef inductionVars, From 8bdf9843f1f8bef19c8c7b53b98a6d80ede0a982 Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Fri, 6 Sep 2024 21:14:41 +0800 Subject: [PATCH 39/66] fix too many parameters in function --- lib/gc/Transforms/CPUPhysicalRegisterPass.cpp | 773 ++++++++---------- lib/gc/Transforms/TilingVector.h | 181 +++- 2 files changed, 474 insertions(+), 480 deletions(-) diff --git a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp index 92e208070..3130760e5 100644 --- a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp +++ b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp @@ -37,7 +37,7 @@ namespace { /// TODO: remove it in the future bool disableSpecialOp = false; bool disableBroadcastOp = false; -bool enableDebugPrinter = true; +bool enableDebugPrinter = false; void printQueue(const std::queue &opQueue) { auto tempQ(opQueue); @@ -146,8 +146,73 @@ bool isNotSupportOperation(Operation *op) { vector::MaskedStoreOp, vector::CreateMaskOp>(op); } -/// whether operation is operate on dynamic shape +/// Get vector type of the operation \param op +/// \param isPrevOp whether the operation is a previous operation, if it is not +/// prev-op, may need to use result vectortype +/// default will return the opeation result type +mlir::FailureOr getOperationVectorType(Operation *op, + bool isPrevOp = true) { + if (!op) { + return failure(); 
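A short caller sketch, not part of the patch, may help clarify how this helper is consumed; the function name below is hypothetical and the FailureOr<VectorType> return type is restored from the surrounding declaration.

// Typical use: treat an operation as non-vectorizable when no static vector type is recoverable.
static bool exampleHasVectorSemantics(Operation *op) {
  mlir::FailureOr<VectorType> vecTy = getOperationVectorType(op, /*isPrevOp=*/true);
  if (failed(vecTy))
    return false; // null op, no usable vector result/operand, or dynamic shape
  // For most ops isPrevOp == true selects the result vector type and false selects the
  // first operand's type; transfer_read/transfer_write always report the vector they
  // read or write (see the TypeSwitch in getOperationVectorType).
  return true;
}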
+ } + auto isDynamicType = [](VectorType &type) { return !type.hasStaticShape(); }; + auto ret = + TypeSwitch>(op) + .Case( + [&](vector::TransferWriteOp transferWriteOp) + -> mlir::FailureOr { + if (auto retType = dyn_cast( + transferWriteOp.getOperandTypes()[0])) + return retType; + + LDBG("TransferWrite Operation has wrong vector to write."); + return failure(); + }) + .Case( + [&](vector::TransferReadOp transferReadOp) + -> mlir::FailureOr { + return transferReadOp.getVectorType(); + }) + .Case( + [&](vector::MultiDimReductionOp multiReductionOp) { + if (isPrevOp) + return cast( + multiReductionOp->getResultTypes()[0]); + + // TODO: may need to add accumulate value vectortype + return cast(multiReductionOp.getSourceVectorType()); + }) + .Default([&](Operation *op) -> mlir::FailureOr { + if (isPrevOp) { + if (op->getResultTypes().empty()) + return failure(); + + if (auto shapedType = + dyn_cast(op->getResultTypes()[0])) + return shapedType; + + return failure(); + } + if (op->getOperandTypes().empty()) + return failure(); + + if (auto shapedType = + dyn_cast(op->getOperandTypes()[0])) { + return shapedType; + } + return failure(); + }); + if (!failed(ret) and isDynamicType(ret.value())) { + return failure(); + } + return ret; +} + +/// whether the vector operation is operate on dynamic shape bool hasDynamicShape(Operation *op) { + if (failed(getOperationVectorType(op))) { + return false; + } auto isDynamicShapedType = [](Value x) { if (auto type = dyn_cast(x.getType())) if (ShapedType::isDynamicShape(type.getShape())) @@ -245,9 +310,8 @@ FailureOr createArithSplatConstantOp(IRRewriter &rewriter, const Location &loc, DenseElementsAttr valueType, VectorType newOperandType) { - if (not valueType.isSplat()) { + if (not valueType.isSplat()) return failure(); - } TypedAttr attr; if (isa(newOperandType.getElementType())) { @@ -258,69 +322,6 @@ FailureOr createArithSplatConstantOp(IRRewriter &rewriter, return rewriter.create(loc, attr)->getResults()[0]; } -/// Get vector type of the operation \param op -/// \param isPrevOp whether the operation is a previous operation, if it is not -/// prev-op, may need to use result vectortype -/// default will return the opeation result type -mlir::FailureOr getOperationVectorType(Operation *op, - bool isPrevOp = true) { - if (!op) { - return failure(); - } - auto isDynamicType = [](VectorType &type) { return !type.hasStaticShape(); }; - auto ret = - TypeSwitch>(op) - .Case( - [&](vector::TransferWriteOp transferWriteOp) - -> mlir::FailureOr { - auto retType = - dyn_cast(transferWriteOp.getOperandTypes()[0]); - if (retType) { - return retType; - } - LDBG("TransferWrite Operation has wrong vector to write."); - return failure(); - }) - .Case( - [&](vector::TransferReadOp transferReadOp) - -> mlir::FailureOr { - return transferReadOp.getVectorType(); - }) - .Case( - [&](vector::MultiDimReductionOp multiReductionOp) { - if (isPrevOp) { - return cast( - multiReductionOp->getResultTypes()[0]); - } - // TODO: may need to add accumulate value vectortype - return cast(multiReductionOp.getSourceVectorType()); - }) - .Default([&](Operation *op) -> mlir::FailureOr { - if (isPrevOp) { - if (op->getResultTypes().empty()) { - return failure(); - } - if (auto shapedType = - dyn_cast(op->getResultTypes()[0])) { - return shapedType; - } - return failure(); - } - if (op->getOperandTypes().empty()) { - return failure(); - } - if (auto shapedType = - dyn_cast(op->getOperandTypes()[0])) { - return shapedType; - } - return failure(); - }); - if (!failed(ret) and 
isDynamicType(ret.value())) { - return failure(); - } - return ret; -} - /// get operation vector type /// \param isPrevOp whether the operation is a previous operation, if it is not /// prev-op, may need to use result vectortype @@ -335,12 +336,9 @@ mlir::FailureOr getOperationMaxVectorType(Operation *op) { .Case( [&](vector::TransferWriteOp transferWriteOp) -> mlir::FailureOr { - auto retType = - cast(transferWriteOp.getOperandTypes()[0]); - if (retType) + if (auto retType = + cast(transferWriteOp.getOperandTypes()[0])) return retType; - - LDBG("TransferWrite Operation has wrong vector to write."); return failure(); }) .Case( @@ -378,7 +376,7 @@ VectorType TypeHelper::getVectorzedType(Operation *op, uint32_t loop_step) { mlir::FailureOr baseType = getOperationVectorType(op); if (failed(baseType)) { LDBG("Failed to get vector type for operation: " << *op << "\n"); - assert(false && "Failed to get vector type for operation"); + assert(0 && "Failed to get vector type for operation"); return VectorType(); } auto vectorizedType = baseType.value(); @@ -953,22 +951,6 @@ vector::TransferReadOp ForLoopGenerator::cloneReductionTransferRead( return newReadOp; } -vector::TransferWriteOp -makeNewTransferWriteOp(Value source, IRMapping &writeMap, OpBuilder &b, - const SmallVector ¶llelAxis, - SmallVector &inductionVars) { - IRRewriter bodyRewriter(b); - auto writeOp = source.getDefiningOp(); - auto newWriteOp = cast(b.clone(*writeOp, writeMap)); - updateReduceReadWriteOperationOperand(inductionVars, parallelAxis, newWriteOp, - MultiReduceOpAxisKind::Parallel); - setOpVectorizationPermutationMap( - newWriteOp, b, cast(newWriteOp->getResult(0).getType()), - newWriteOp.getPermutationMap()); - bodyRewriter.replaceOp(writeOp, newWriteOp); - return newWriteOp; -} - Value makeIndexArithConstantOp(OpBuilder &opBuilder, const Location &loc, int64_t x) { return opBuilder.create( @@ -1103,9 +1085,8 @@ void ForLoopGenerator::replaceOperationsWithForLoopResult( while (!tmpQ.empty()) { auto curOp = tmpQ.front(); tmpQ.pop(); - for (auto x : curOp->getOperands()) { + for (auto x : curOp->getOperands()) operationOperands.insert(x); - } } auto replaceIfFn = [&](OpOperand &use) { return operationOperands.contains(use.get()); @@ -1125,73 +1106,65 @@ void ForLoopGenerator::replaceOperationsWithForLoopResult( /// \param [in, out] originalOperandLoopArgsMap /// \param [in, out] LoopArgsoriginalOperandMap void ForLoopGenerator::movePreOpToCurrentAnchor( - const size_t anchorIdx, const size_t groupIdx, OpBuilder &b, - ArrayRef inductionVars, const ValueRange &loopState, - DenseMap ¤tLoopStateIdxMap, - DenseMap &nextAnchorArgsIdxMap, + OpBuilder &b, DenseMap &nextLoopStateIdxMap, SmallVector &nextAnchorArgs, - std::queue &candidateQueue, - std::queue &movedQueue, - DenseMap &originalOperandLoopArgsMap, - DenseMap &LoopArgsOriginalOperandMap, - DenseMap> &indiceLoopMap) { + GenerateLoopHelper &loopHelperParam) { // 1. get operations in current anchor position std::queue movingOperation; - getOperationInCurrentAnchor(anchorIdx, candidateQueue, movingOperation); + getOperationInCurrentAnchor(loopHelperParam.anchorIdx, + *loopHelperParam.candidateOps, movingOperation); // 2. rewrite operation as vectorize IR - rewriteOperationAsVectorize(b, groupIdx, &movingOperation); + rewriteOperationAsVectorize(b, loopHelperParam.groupIdx, &movingOperation); // 3. 
move opeartions to current for block moveOperationsToCurrentForBody( - groupIdx, b, inductionVars, currentLoopStateIdxMap, loopState, - originalOperandLoopArgsMap, movingOperation, indiceLoopMap); + loopHelperParam.groupIdx, b, loopHelperParam.inductionVars, + loopHelperParam.currentLoopStateIdxMap, loopHelperParam.loopIterArgs, + loopHelperParam.originalOperandLoopArgsMap, movingOperation, + loopHelperParam.indiceLoopMap); // 4. get next anchor args - getInitArgsToNextAnchor(anchorIdx, groupIdx, candidateQueue, loopState, - currentLoopStateIdxMap, nextAnchorArgsIdxMap, - nextAnchorArgs, originalOperandLoopArgsMap, - LoopArgsOriginalOperandMap); + getInitArgsToNextAnchor( + loopHelperParam.anchorIdx, loopHelperParam.groupIdx, + *loopHelperParam.candidateOps, loopHelperParam.loopIterArgs, + loopHelperParam.currentLoopStateIdxMap, nextLoopStateIdxMap, + nextAnchorArgs, loopHelperParam.originalOperandLoopArgsMap, + loopHelperParam.loopArgsOriginalOperandMap); // 5. move operations to moved queue while (!movingOperation.empty()) { - movedQueue.push(movingOperation.front()); + loopHelperParam.movedOps->push(movingOperation.front()); movingOperation.pop(); } } void ForLoopGenerator::movePostOpToCurrentAnchor( - OpBuilder &b, const int anchorIdx, const int groupIdx, - const ValueRange &forResults, const Block *forBlock, - std::queue &candidateOps, std::queue &movedOps, - ArrayRef inductionVars, const DenseMap &operandIdxMap, - const ValueRange &loopState, - DenseMap &originalOperandLoopArgsMap, - DenseMap &loopArgsOriginalOperandMap, - const SmallVector &nextAnchorResults, - DenseMap &forResultOrignalResultMap, - DenseMap> &indiceLoopMap) { + OpBuilder &b, GenerateLoopHelper &loopHelperParam) { // 1. move post-op to current loop body std::queue movingOperations; - getOperationInCurrentAnchor(anchorIdx, candidateOps, movingOperations); + getOperationInCurrentAnchor(loopHelperParam.anchorIdx, + *loopHelperParam.candidateOps, movingOperations); + rewriteOperationAsVectorize(b, loopHelperParam.groupIdx, &movingOperations); - rewriteOperationAsVectorize(b, groupIdx, &movingOperations); - - moveOperationsToCurrentForBody(anchorIdx, b, inductionVars, operandIdxMap, - loopState, originalOperandLoopArgsMap, - movingOperations, indiceLoopMap); + moveOperationsToCurrentForBody( + loopHelperParam.anchorIdx, b, loopHelperParam.inductionVars, + loopHelperParam.currentLoopStateIdxMap, loopHelperParam.loopIterArgs, + loopHelperParam.originalOperandLoopArgsMap, movingOperations, + loopHelperParam.indiceLoopMap); // 2. replace correct for loop result to post-op IRRewriter rewriter(b); - replaceOperationsWithForLoopResult(rewriter, forResults, forBlock, - nextAnchorResults, movingOperations, - forResultOrignalResultMap); + replaceOperationsWithForLoopResult( + rewriter, loopHelperParam.forResults, loopHelperParam.forBlock, + loopHelperParam.nextAnchorResults, movingOperations, + loopHelperParam.nextAnchorResultOrignalResultMap); // 3. 
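The division of labour between the pre-op and post-op helpers is easiest to see as a calling outline; the comment block below is an added summary and only restates what constructNestedForOp and reductionAxisGenerateForLoop in this patch already do with GenerateLoopHelper.

// Calling protocol used by the nested-loop builders around these helpers (outline only):
//   1. snapshot loopState and call loopHelper.updateDataBeforePreOpMove(...)
//   2. movePreOpToCurrentAnchor(...)             -> hoist producer ops into the current anchor
//   3. loopHelper.updateDataAfterPreOpMove(...)  -> switch to the next anchor's iter args
//   4. ++loopHelper.anchorIdx, recurse to build the inner scf.for, then --loopHelper.anchorIdx
//   5. loopHelper.updateDataBeforePostOpMove(...) with the inner loop's results
//   6. movePostOpToCurrentAnchor(...)            -> sink consumer ops and rewire them to those results
//   7. generateLoopResults(...) and maybeYieldValue(b, loc, loopHelper.nextAnchorResults)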
move operations to moved queue while (!movingOperations.empty()) { - movedOps.push(movingOperations.front()); + loopHelperParam.movedOps->push(movingOperations.front()); movingOperations.pop(); } } @@ -1206,6 +1179,8 @@ void ForLoopGenerator::generateLoopResults( DenseMap &nextOperandIdxMap) { SmallVector results; DenseMap currentResultMap; + llvm::outs() << " move current operation to current for block\n"; + printQueue(movedOperation); getResultInCurrentOps(anchorIdx, groupIdx, movedOperation, results, nextAnchorResultsIdxMap, currentResultMap); @@ -1239,32 +1214,24 @@ void ForLoopGenerator::generateLoopResults( } scf::ForOp ForLoopGenerator::reductionAxisGenerateForLoop( - OpBuilder &opBuilder, const int groupIdx, const size_t reductionIdx, - const int anchorIdx, llvm::DenseMap ¤tLoopStateIdxMap, - const ValueRange &initArgs, - DenseMap &originalOperandLoopArgsMap, - DenseMap &loopArgsOriginalOperandMap, - llvm::SmallVector &nextAnchorResults, - DenseMap &nextAnchorResultsIdxMap, - llvm::SmallVector &inductionVars, - DenseMap &forResultOrignalResultMap, - DenseMap &originalResultForResultMap, - DenseMap> &indiceLoopMap) { + OpBuilder &opBuilder, const size_t reductionIdx, + GenerateLoopHelper &loopHelperParam) { MultiReductionCanonicalizer rdCanonicalizer = - getMultiRdCanonicalizers()[groupIdx]; + getMultiRdCanonicalizers()[loopHelperParam.groupIdx]; auto &multireductionOp = rdCanonicalizer.getCandidateOps()[0]; VectorFusionStrategy &fusionStrategy = getFusionStrategy(); SmallVector, 8> &opGroups = fusionStrategy.getOpGroups(); - std::queue &opQueue = opGroups[groupIdx]; + std::queue &opQueue = opGroups[loopHelperParam.groupIdx]; const auto loc = multireductionOp->getLoc(); SmallVector &reductionAxis = rdCanonicalizer.getReductionAxis(); bool lastDimReduction = rdCanonicalizer.hasLastDimReduction(); VectorType vectorType = rdCanonicalizer.getSourceType(); - const int loopStep = getFusionStrategy().getGroupMaxSteps()[groupIdx]; + const int loopStep = + getFusionStrategy().getGroupMaxSteps()[loopHelperParam.groupIdx]; IRRewriter rewriterOfFunc(func); @@ -1276,9 +1243,11 @@ scf::ForOp ForLoopGenerator::reductionAxisGenerateForLoop( Value numIter = makeIndexArithConstantOp( opBuilder, loc, vectorType.getShape()[reductionAxis[reductionIdx]]); scf::ForOp forOp = opBuilder.create( - loc, zero, numIter, forSteps, initArgs, + loc, zero, numIter, forSteps, loopHelperParam.loopIterArgs, [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) { - inductionVars.emplace_back(iv); + loopHelperParam.inductionVars.emplace_back(iv); + size_t currentAnchorId = loopHelperParam.anchorIdx; + SmallVector tmpArgs(loopState); if (reductionIdx < reductionAxis.size() - 1) { @@ -1287,66 +1256,78 @@ scf::ForOp ForLoopGenerator::reductionAxisGenerateForLoop( SmallVector nextAnchorArgs; std::queue movedOperation; DenseMap currentoriginalArgsMap = - originalOperandLoopArgsMap; + loopHelperParam.originalOperandLoopArgsMap; DenseMap currentArgsOriginalMap = - loopArgsOriginalOperandMap; + loopHelperParam.loopArgsOriginalOperandMap; + DenseMap currentArgsIdxMap = + loopHelperParam.currentLoopStateIdxMap; DenseMap originalArgsMap, argsOriginalMap; - movePreOpToCurrentAnchor(anchorIdx, groupIdx, b, inductionVars, - loopState, currentLoopStateIdxMap, - nextAnchorArgsIdxMap, nextAnchorArgs, - opQueue, movedOperation, originalArgsMap, - argsOriginalMap, indiceLoopMap); + loopHelperParam.updateDataBeforePreOpMove(tmpArgs, opQueue, + movedOperation); + movePreOpToCurrentAnchor(b, nextAnchorArgsIdxMap, nextAnchorArgs, + 
loopHelperParam); + loopHelperParam.updateDataAfterPreOpMove(nextAnchorArgsIdxMap, + nextAnchorArgs); // replace reduction init args - if (originalOperandLoopArgsMap.contains(multireductionOp.getAcc())) { - size_t accValIdx = currentLoopStateIdxMap - [originalOperandLoopArgsMap[multireductionOp.getAcc()]]; + if (loopHelperParam.originalOperandLoopArgsMap.contains( + multireductionOp.getAcc())) { + size_t accValIdx = + loopHelperParam.currentLoopStateIdxMap + [loopHelperParam.originalOperandLoopArgsMap[multireductionOp + .getAcc()]]; updateCurrentArgsStatus( loopState, accValIdx, nextAnchorArgs, multireductionOp.getAcc(), nextAnchorArgsIdxMap, originalArgsMap, argsOriginalMap); + loopHelperParam.updateCurrentArgsStatus( + nextAnchorArgsIdxMap, nextAnchorArgs, originalArgsMap, + argsOriginalMap); } + loopHelperParam.anchorIdx += 1; // 2. generate next for loop - scf::ForOp nxtFor = reductionAxisGenerateForLoop( - b, groupIdx, reductionIdx + 1, anchorIdx + 1, - nextAnchorArgsIdxMap, nextAnchorArgs, originalArgsMap, - argsOriginalMap, nextAnchorResults, nextAnchorResultsIdxMap, - inductionVars, forResultOrignalResultMap, - originalResultForResultMap, indiceLoopMap); - + scf::ForOp nxtFor = reductionAxisGenerateForLoop(b, reductionIdx + 1, + loopHelperParam); + loopHelperParam.anchorIdx -= 1; + + loopHelperParam.updateDataBeforePostOpMove( + tmpArgs, currentArgsIdxMap, currentoriginalArgsMap, + currentArgsOriginalMap, nxtFor->getResults(), b.getBlock(), + movedOperation, currentAnchorId); // 3. move postOp to current body - movePostOpToCurrentAnchor( - b, anchorIdx, groupIdx, nxtFor->getResults(), b.getBlock(), - opQueue, movedOperation, inductionVars, currentLoopStateIdxMap, - loopState, currentoriginalArgsMap, currentArgsOriginalMap, - nextAnchorResults, forResultOrignalResultMap, indiceLoopMap); + movePostOpToCurrentAnchor(b, loopHelperParam); // 4. 
generate loop results - generateLoopResults(b, loc, anchorIdx, groupIdx, nextAnchorResults, - nextAnchorResultsIdxMap, nxtFor->getResults(), - movedOperation, forResultOrignalResultMap, + generateLoopResults(b, loc, currentAnchorId, loopHelperParam.groupIdx, + loopHelperParam.nextAnchorResults, + loopHelperParam.nextAnchorResultsIdxMap, + nxtFor->getResults(), *loopHelperParam.movedOps, + loopHelperParam.nextAnchorResultOrignalResultMap, loopState, currentArgsOriginalMap, nextAnchorArgsIdxMap); // reduction must return accumulate - if (originalResultForResultMap.contains( + if (loopHelperParam.orignalResultNextAnchorResultMap.contains( multireductionOp->getResults()[0])) { Value lastForResult = - originalResultForResultMap[multireductionOp->getResults()[0]]; - size_t retIdx = - nextAnchorArgsIdxMap[forResultOrignalResultMap[lastForResult]]; + loopHelperParam.orignalResultNextAnchorResultMap + [multireductionOp->getResults()[0]]; + size_t retIdx = nextAnchorArgsIdxMap + [loopHelperParam + .nextAnchorResultOrignalResultMap[lastForResult]]; Value forRes = nxtFor->getResults()[retIdx]; // accumulate for loop iter args must be last, so we just put the // reduction result as the last result - nextAnchorResults.emplace_back(forRes); - nextAnchorResultsIdxMap[forRes] = nextAnchorResults.size() - 1; - forResultOrignalResultMap[forRes] = + loopHelperParam.nextAnchorResults.emplace_back(forRes); + loopHelperParam.nextAnchorResultsIdxMap[forRes] = + loopHelperParam.nextAnchorResults.size() - 1; + loopHelperParam.nextAnchorResultOrignalResultMap[forRes] = multireductionOp->getResults()[0]; - originalResultForResultMap[multireductionOp->getResults()[0]] = - forRes; + loopHelperParam.orignalResultNextAnchorResultMap + [multireductionOp->getResults()[0]] = forRes; } - maybeYieldValue(b, loc, nextAnchorResults); + maybeYieldValue(b, loc, loopHelperParam.nextAnchorResults); } else if (reductionIdx == reductionAxis.size() - 1) { std::queue movingOperation; @@ -1354,9 +1335,9 @@ scf::ForOp ForLoopGenerator::reductionAxisGenerateForLoop( while (!opQueue.empty()) { Operation *curOp = opQueue.front(); opQueue.pop(); - if (isa(curOp)) { + if (isa(curOp)) break; - } + movingOperation.push(curOp); } // remove all the multi_reduction operation @@ -1369,52 +1350,47 @@ scf::ForOp ForLoopGenerator::reductionAxisGenerateForLoop( break; } - rewriteOperationAsVectorize(b, groupIdx, &movingOperation); + rewriteOperationAsVectorize(b, loopHelperParam.groupIdx, + &movingOperation); moveOperationsToCurrentForBody( - groupIdx, b, inductionVars, currentLoopStateIdxMap, loopState, - originalOperandLoopArgsMap, movingOperation, indiceLoopMap); - // if (!rdCanonicalizer.getIsEmptyReduction()) { - int accValIdx = currentLoopStateIdxMap - [originalOperandLoopArgsMap[multireductionOp.getAcc()]]; - // check acc val is the first args - // assert(accValIdx == 0); + loopHelperParam.groupIdx, b, loopHelperParam.inductionVars, + loopHelperParam.currentLoopStateIdxMap, loopState, + loopHelperParam.originalOperandLoopArgsMap, movingOperation, + loopHelperParam.indiceLoopMap); + loopHelperParam.movedOps = &movingOperation; + loopHelperParam.candidateOps = &opQueue; + + int accValIdx = + loopHelperParam.currentLoopStateIdxMap + [loopHelperParam + .originalOperandLoopArgsMap[multireductionOp.getAcc()]]; Value reductionResult = makeArithReduction( b, loc, multireductionOp.getKind(), multireductionOp.getSource(), loopState[accValIdx]); - movePostOpToCurrentAnchor( - b, anchorIdx, groupIdx, ValueRange(), b.getBlock(), opQueue, - movingOperation, 
inductionVars, currentLoopStateIdxMap, loopState, - originalOperandLoopArgsMap, loopArgsOriginalOperandMap, - nextAnchorResults, forResultOrignalResultMap, indiceLoopMap); + loopHelperParam.updateDataBeforePostOpMove( + tmpArgs, loopHelperParam.currentLoopStateIdxMap, + loopHelperParam.originalOperandLoopArgsMap, + loopHelperParam.loopArgsOriginalOperandMap, ValueRange(), + b.getBlock(), movingOperation, currentAnchorId); - nextAnchorResults.clear(); - nextAnchorResults.emplace_back(reductionResult); - nextAnchorResultsIdxMap[reductionResult] = 0; - forResultOrignalResultMap[reductionResult] = + movePostOpToCurrentAnchor(b, loopHelperParam); + + loopHelperParam.nextAnchorResults.clear(); + loopHelperParam.nextAnchorResults.emplace_back(reductionResult); + loopHelperParam.nextAnchorResultsIdxMap[reductionResult] = 0; + loopHelperParam.nextAnchorResultOrignalResultMap[reductionResult] = multireductionOp->getResults()[0]; - originalResultForResultMap[multireductionOp->getResults()[0]] = - reductionResult; - getResultInCurrentOps(anchorIdx, groupIdx, movingOperation, - nextAnchorResults, nextAnchorResultsIdxMap, - forResultOrignalResultMap); - - // } else { - // Value sourceVal = multireductionOp.getSource(); - // nextAnchorResults.clear(); - // nextAnchorResults.emplace_back(sourceVal); - // nextAnchorResultsIdxMap[sourceVal] = 0; - // forResultOrignalResultMap[sourceVal] = - // multireductionOp->getResults()[0]; - // originalResultForResultMap[multireductionOp->getResults()[0]] = - // sourceVal; - // getResultInCurrentOps(anchorIdx, groupIdx, movingOperation, - // nextAnchorResults, - // forResultOrignalResultMap); - // } - maybeYieldValue(b, loc, nextAnchorResults); + loopHelperParam.orignalResultNextAnchorResultMap + [multireductionOp->getResults()[0]] = reductionResult; + getResultInCurrentOps( + loopHelperParam.anchorIdx, loopHelperParam.groupIdx, + movingOperation, loopHelperParam.nextAnchorResults, + loopHelperParam.nextAnchorResultsIdxMap, + loopHelperParam.nextAnchorResultOrignalResultMap); + maybeYieldValue(b, loc, loopHelperParam.nextAnchorResults); } }); @@ -1424,11 +1400,11 @@ scf::ForOp ForLoopGenerator::reductionAxisGenerateForLoop( /// Generate for loop for parallel axis of `vector.multi_reduction`. /// This function also call reduction axis for loop scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( - OpBuilder &opBuilder, const int groupIdx, const size_t parallelIdx, + OpBuilder &opBuilder, DenseMap> &indiceLoopMap, GenerateLoopHelper &loopHelperParam) { MultiReductionCanonicalizer &rdCanonicalizer = - getMultiRdCanonicalizers()[groupIdx]; + getMultiRdCanonicalizers()[loopHelperParam.groupIdx]; vector::MultiDimReductionOp &multiReductionOp = rdCanonicalizer.getCandidateOps()[0]; VectorType vectorType = rdCanonicalizer.getSourceType(); @@ -1437,8 +1413,9 @@ scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( SmallVector ¶llelAxis = rdCanonicalizer.getParallelAxis(); const Location &loc = multiReductionOp.getLoc(); Value zero = makeIndexArithConstantOp(opBuilder, loc, 0); - size_t grpMaxStep = getFusionStrategy().getGroupMaxSteps()[groupIdx]; - size_t actualStep = (parallelIdx == parallelAxis.size() - 1 and + size_t grpMaxStep = + getFusionStrategy().getGroupMaxSteps()[loopHelperParam.groupIdx]; + size_t actualStep = (loopHelperParam.anchorIdx == parallelAxis.size() - 1 and !rdCanonicalizer.getHasLastDimReduction()) ? 
grpMaxStep : 1; @@ -1446,10 +1423,10 @@ scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( // last dim reduction need to a generate dim=16 loop for fused with pre-op int dimSize = 0; - if (parallelIdx == parallelAxis.size()) - dimSize = getFusionStrategy().getGroupMaxSteps()[groupIdx]; + if (loopHelperParam.anchorIdx == parallelAxis.size()) + dimSize = getFusionStrategy().getGroupMaxSteps()[loopHelperParam.groupIdx]; else - dimSize = vectorType.getShape()[parallelAxis[parallelIdx]]; + dimSize = vectorType.getShape()[parallelAxis[loopHelperParam.anchorIdx]]; Value numIter = makeIndexArithConstantOp(opBuilder, loc, dimSize); // Create a loop and move vectorized operation into loops. @@ -1470,7 +1447,7 @@ scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( std::queue &opQueue = opGroups[opIndex]; Value multiReductionAcc = multiReductionOp.getAcc(); - if (parallelIdx < parallelAxis.size()) { + if (loopHelperParam.anchorIdx < parallelAxis.size()) { // 1. move pre-Op to current body DenseMap nextAnchorArgsIdxMap; SmallVector nextAnchorArgs; @@ -1481,17 +1458,15 @@ scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( loopHelperParam.loopArgsOriginalOperandMap; DenseMap currentLoopStateIdxMap = loopHelperParam.currentLoopStateIdxMap; - - movePreOpToCurrentAnchor( - parallelIdx, groupIdx, b, loopHelperParam.inductionVars, - loopState, loopHelperParam.currentLoopStateIdxMap, - nextAnchorArgsIdxMap, nextAnchorArgs, opQueue, movedQueue, - loopHelperParam.originalOperandLoopArgsMap, - loopHelperParam.loopArgsOriginalOperandMap, indiceLoopMap); - loopHelperParam.loopIterArgs = nextAnchorArgs; - loopHelperParam.currentLoopStateIdxMap = nextAnchorArgsIdxMap; - - if (parallelIdx == parallelAxis.size() - 1) { + SmallVector tmpArgs(loopState); + loopHelperParam.updateDataBeforePreOpMove(tmpArgs, opQueue, + movedQueue); + movePreOpToCurrentAnchor(b, nextAnchorArgsIdxMap, nextAnchorArgs, + loopHelperParam); + loopHelperParam.updateDataAfterPreOpMove(nextAnchorArgsIdxMap, + nextAnchorArgs); + + if (loopHelperParam.anchorIdx == parallelAxis.size() - 1) { // Ensure accumalate expression appear in this parallel anchor // position. If it not appear in current anchor, we must move it in // here. @@ -1538,38 +1513,32 @@ scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( DenseMap originalResultForResultMap; // 2. 
generate next for loop if (rdCanonicalizer.hasLastDimReduction() or - parallelIdx < parallelAxis.size() - 1) - nxtFor = parallelAxisGenerateForLoop( - b, groupIdx, parallelIdx + 1, indiceLoopMap, loopHelperParam); - else if (parallelAxis.size() - 1 == parallelIdx) - nxtFor = reductionAxisGenerateForLoop( - b, groupIdx, 0, parallelIdx + 1, nextAnchorArgsIdxMap, - nextAnchorArgs, loopHelperParam.originalOperandLoopArgsMap, - loopHelperParam.loopArgsOriginalOperandMap, - loopHelperParam.nextAnchorResults, - loopHelperParam.nextAnchorResultsIdxMap, - loopHelperParam.inductionVars, - loopHelperParam.nextAnchorResultOrignalResultMap, - originalResultForResultMap, indiceLoopMap); + loopHelperParam.anchorIdx < parallelAxis.size() - 1) { + loopHelperParam.anchorIdx += 1; + nxtFor = + parallelAxisGenerateForLoop(b, indiceLoopMap, loopHelperParam); + } else if (parallelAxis.size() - 1 == loopHelperParam.anchorIdx) { + loopHelperParam.anchorIdx += 1; + nxtFor = reductionAxisGenerateForLoop(b, 0, loopHelperParam); + } + loopHelperParam.anchorIdx -= 1; + loopHelperParam.updateDataBeforePostOpMove( + tmpArgs, currentLoopStateIdxMap, currentOriginalOperandMap, + currentOperandOriginalMap, nxtFor->getResults(), + nxtFor->getBlock(), movedQueue, loopHelperParam.anchorIdx); // 3. move postOp to current body - movePostOpToCurrentAnchor( - b, parallelIdx, groupIdx, nxtFor->getResults(), - nxtFor->getBlock(), opQueue, movedQueue, - loopHelperParam.inductionVars, currentLoopStateIdxMap, loopState, - currentOriginalOperandMap, currentOperandOriginalMap, - loopHelperParam.nextAnchorResults, - loopHelperParam.nextAnchorResultOrignalResultMap, indiceLoopMap); - + movePostOpToCurrentAnchor(b, loopHelperParam); // 4. generate loop results generateLoopResults( - b, loc, parallelIdx, groupIdx, loopHelperParam.nextAnchorResults, + b, loc, loopHelperParam.anchorIdx, loopHelperParam.groupIdx, + loopHelperParam.nextAnchorResults, loopHelperParam.nextAnchorResultsIdxMap, nxtFor->getResults(), movedQueue, loopHelperParam.nextAnchorResultOrignalResultMap, loopState, currentOperandOriginalMap, nextAnchorArgsIdxMap); maybeYieldValue(b, loc, loopHelperParam.nextAnchorResults); - } else if (parallelIdx == parallelAxis.size()) { + } else if (loopHelperParam.anchorIdx == parallelAxis.size()) { DenseMap tmpOriginOperandLoopArgsMap = loopHelperParam.originalOperandLoopArgsMap; @@ -1590,7 +1559,7 @@ scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( DenseMap localOriginalOperandLoopArgsMap, localLoopArgsOriginalOperandMap; - SmallVector argsArray; + SmallVector argsArray; argsArray.emplace_back(accVal); localAnchorArgsIdxMap[accVal] = 0; size_t accLoopStateIdx = @@ -1611,15 +1580,11 @@ scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( localOriginalOperandLoopArgsMap[originalValue] = x; localLoopArgsOriginalOperandMap[x] = originalValue; } + loopHelperParam.updateCurrentArgsStatus( + localAnchorArgsIdxMap, argsArray, localOriginalOperandLoopArgsMap, + localLoopArgsOriginalOperandMap); DenseMap originalResultForResultMap; - auto nxtFor = reductionAxisGenerateForLoop( - b, groupIdx, 0, parallelIdx, localAnchorArgsIdxMap, argsArray, - localOriginalOperandLoopArgsMap, localLoopArgsOriginalOperandMap, - loopHelperParam.nextAnchorResults, - loopHelperParam.nextAnchorResultsIdxMap, - loopHelperParam.inductionVars, - loopHelperParam.nextAnchorResultOrignalResultMap, - originalResultForResultMap, indiceLoopMap); + auto nxtFor = reductionAxisGenerateForLoop(b, 0, loopHelperParam); // insert accumulate value to original vector Value 
nxtForAccVal = @@ -1656,14 +1621,9 @@ scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( currentResultIdxMap[nxtFor->getResults()[idx]] = itrIdx; currentResultMap[nxtFor->getResults()[idx]] = originalResult; } - loopHelperParam.nextAnchorResults.clear(); - loopHelperParam.nextAnchorResultOrignalResultMap.clear(); - loopHelperParam.nextAnchorResultsIdxMap.clear(); - loopHelperParam.nextAnchorResults = std::move(currentAnchorResults); - loopHelperParam.nextAnchorResultOrignalResultMap = - std::move(currentResultMap); - loopHelperParam.nextAnchorResultsIdxMap = - std::move(currentResultIdxMap); + loopHelperParam.clearNextAnchorResults(); + loopHelperParam.setNextAnchorResults( + currentAnchorResults, currentResultMap, currentResultIdxMap); maybeYieldValue(b, loc, loopHelperParam.nextAnchorResults); } }); @@ -1835,7 +1795,7 @@ ForLoopGenerator::generateMultiReductionForLoop(const size_t grpIdx) { rearrageMultiReductionIR(grpIdx, indiceLoopMap); // get current loop init args DenseMap currentLoopStateIdxMap, nextAnchorResultsIdxMap; - GenerateLoopHelper loopHelper; + GenerateLoopHelper loopHelper(grpIdx, 0); prepareForLoopArgs(grpIdx, loopHelper); MultiReductionCanonicalizer &rdCanonicalizer = @@ -1843,8 +1803,8 @@ ForLoopGenerator::generateMultiReductionForLoop(const size_t grpIdx) { OpBuilder opBuilder(rdCanonicalizer.getCandidateOps()[0]); - scf::ForOp forOp = parallelAxisGenerateForLoop(opBuilder, grpIdx, 0, - indiceLoopMap, loopHelper); + scf::ForOp forOp = + parallelAxisGenerateForLoop(opBuilder, indiceLoopMap, loopHelper); replaceOpUsersWithForLoopResult(forOp, grpIdx, loopHelper.nextAnchorResults, loopHelper.nextAnchorResultsIdxMap, loopHelper.nextAnchorResultOrignalResultMap); @@ -2074,9 +2034,8 @@ void ForLoopGenerator::rectifyWriteOperationIndice( successWriteTensorType.getRank() - sucessWriteVectorType.getRank(); Operation::operand_range writeIndices = originalWriteOp->getIndices(); - for (size_t i = 0; i < inMutableIdx; i++) { + for (size_t i = 0; i < inMutableIdx; i++) writeVars[i] = writeIndices[i]; - } } void ForLoopGenerator::rectifyReadOperationIndice( @@ -2087,9 +2046,8 @@ void ForLoopGenerator::rectifyReadOperationIndice( // currently only broadcast (fuse as transfer_read) will move into more inner // loop if (readTensorType.getRank() - 1 >= - getFusionStrategy().getOpAnchorPos()[*originalReadOp]) { + getFusionStrategy().getOpAnchorPos()[*originalReadOp]) return; - } int64_t itrIdx = loopType.getRank() - 1; int64_t readIdx = readTensorType.getRank() - 1; @@ -2117,15 +2075,14 @@ scf::ForOp ForLoopGenerator::generateShapeCastForLoop(const size_t grpIdx) { OpBuilder b(scOp); SmallVector iterArgs; SmallVector successorWriteOps; - for (Operation *x : scOp->getUsers()) { + for (Operation *x : scOp->getUsers()) if (isa(x) and opIndexMap.contains(x) and - opIndexMap[x] == opIndexMap[x]) { + opIndexMap[x] == opIndexMap[x]) successorWriteOps.emplace_back(cast(x)); - } - } - for (auto successorWriteOp : successorWriteOps) { + + for (auto successorWriteOp : successorWriteOps) iterArgs.emplace_back(successorWriteOp->getOperands()[1]); - } + SmallVector inductionVars; IRRewriter rewriter(func); const size_t groupStep = getFusionStrategy().getGroupMaxSteps()[grpIdx]; @@ -2135,24 +2092,22 @@ scf::ForOp ForLoopGenerator::generateShapeCastForLoop(const size_t grpIdx) { bool isDestMultiple = destType.getShape()[destType.getRank() - 1] % groupStep == 0; - if (isDestMultiple and isSourceMultiple and - scCanonicalizer.isReadWriteOnLastDim()) { - scf::ForOp forOp = 
generateShapeCastReadWriteLoop( + scf::ForOp forOp; + bool canVectorizedLoadStore = isDestMultiple and isSourceMultiple and + scCanonicalizer.isReadWriteOnLastDim(); + if (canVectorizedLoadStore) { + forOp = generateShapeCastReadWriteLoop( b, grpIdx, 0, groupStep, scOp.getLoc(), inductionVars, iterArgs); - for (auto [idx, successorWriteOp] : enumerate(successorWriteOps)) { + for (auto [idx, successorWriteOp] : enumerate(successorWriteOps)) + rewriter.replaceOp(successorWriteOp, forOp->getResults()[idx]); + } else { + // scalar data movement + forOp = generateShapeCastReadWriteLoop(b, grpIdx, 0, 1, scOp.getLoc(), + inductionVars, iterArgs); + for (auto [idx, successorWriteOp] : enumerate(successorWriteOps)) rewriter.replaceOp(successorWriteOp, forOp->getResults()[idx]); - } - rewriter.eraseOp(scOp); - clearCurrentOperationGroup(grpIdx); - return forOp; } - // scalar data movement - scf::ForOp forOp = generateShapeCastReadWriteLoop( - b, grpIdx, 0, 1, scOp.getLoc(), inductionVars, iterArgs); - for (auto [idx, successorWriteOp] : enumerate(successorWriteOps)) { - rewriter.replaceOp(successorWriteOp, forOp->getResults()[idx]); - } rewriter.eraseOp(scOp); clearCurrentOperationGroup(grpIdx); return forOp; @@ -2210,7 +2165,7 @@ scf::ForOp ForLoopGenerator::generateTransposeForLoop(const size_t grpIdx) { DenseMap originalOperandMap, operandOriginalMap, resultIdxMap, forResultOrignalResultMap; SmallVector iterArgs; - GenerateLoopHelper loopHelper; + GenerateLoopHelper loopHelper(grpIdx); prepareForLoopArgs(grpIdx, loopHelper); operandIdxMap = loopHelper.currentLoopStateIdxMap; originalOperandMap = loopHelper.originalOperandLoopArgsMap; @@ -2297,10 +2252,6 @@ SmallVector &SpecialOperationCanonicalizer::getCandidateOps() { }; void MultiReductionCanonicalizer::initReductionAxis() { - // auto reductionAxisRange = - // getCandidateOps()[0].getReductionDims().getAsValueRange(); - // auto reductionRange = llvm::to_vector<4>(map_range( - // reductionAxisRange, [](const APInt &a) { return a.getZExtValue(); })); auto reductionAxisRange = getCandidateOps()[0].getReductionDims(); reductionAxis.assign(reductionAxisRange.begin(), reductionAxisRange.end()); llvm::sort(reductionAxis); @@ -2344,7 +2295,7 @@ void MultiReductionCanonicalizer::prepareSpecialOperationInfo() { return; sourceType = getCandidateOps()[0].getSourceVectorType(); - accType = dyn_cast(getCandidateOps()[0].getAcc().getType()); + accType = cast(getCandidateOps()[0].getAcc().getType()); getTypeRank(); getReductionAxisAndParallelAxis(); hasLastDimReduction(); @@ -2615,7 +2566,7 @@ void CanonicalizerVectorOperation::run() { for (size_t idx = 0; idx < fusionStrategy.getOpGroups().size(); ++idx) { generateGroupOpVectorizedIR(idx); } - + func->dump(); // 3. 
Some IR cleanup work DominanceInfo domInfo; eliminateCommonSubExpressions(rewriter, domInfo, func); @@ -2645,12 +2596,9 @@ void CanonicalizerVectorOperation::run() { return false; } - if (mlir::isa(op) || - mlir::isa(op)) { - if (!isReadWriteOnLastDim(op)) { - LDBG("Operation is not last dim read/write" << *op << "\n"); - return false; - } + if (isReadOrWriteOperation(op) and !isReadWriteOnLastDim(op)) { + LDBG("Operation is not last dim read/write" << *op << "\n"); + return false; } return true; @@ -2722,15 +2670,14 @@ void ForLoopGenerator::setOperationCorrectOperand( } scf::ForOp ForLoopGenerator::constructNestedForOp( - const size_t forDimIdx, const size_t groupIdx, OpBuilder &b, - const Location &loc, ArrayRef dims, - GenerateLoopHelper &loopHelper) { + const size_t groupIdx, OpBuilder &b, const Location &loc, + ArrayRef dims, GenerateLoopHelper &loopHelper) { const int loop_step = getFusionStrategy().getGroupMaxSteps()[groupIdx]; // loop initialization variable auto zero = makeIndexArithConstantOp(b, loc, 0); auto forSteps = makeIndexArithConstantOp( - b, loc, forDimIdx == dims.size() - 1 ? loop_step : 1); - auto numIter = makeIndexArithConstantOp(b, loc, dims[forDimIdx]); + b, loc, loopHelper.anchorIdx == dims.size() - 1 ? loop_step : 1); + auto numIter = makeIndexArithConstantOp(b, loc, dims[loopHelper.anchorIdx]); // Create a loop and move vectorized operation into loops. auto forOp = b.create( @@ -2739,12 +2686,13 @@ scf::ForOp ForLoopGenerator::constructNestedForOp( loopHelper.inductionVars.emplace_back(iv); // inner most body of the loop - if (forDimIdx == dims.size() - 1) { + if (loopHelper.anchorIdx == dims.size() - 1) { std::queue &opQueue = getFusionStrategy().getOpGroups()[groupIdx]; // 1. get operations in current anchor position std::queue movingOperation; - getOperationInCurrentAnchor(forDimIdx, opQueue, movingOperation); + getOperationInCurrentAnchor(loopHelper.anchorIdx, opQueue, + movingOperation); // 2. 
rewrite operation as vectorize IR rewriteOperationAsVectorize(b, groupIdx, &movingOperation); @@ -2756,7 +2704,7 @@ scf::ForOp ForLoopGenerator::constructNestedForOp( loopHelper.originalOperandLoopArgsMap, movingOperation, loopHelper.indiceLoopMap); - getResultInCurrentOps(forDimIdx, groupIdx, movingOperation, + getResultInCurrentOps(loopHelper.anchorIdx, groupIdx, movingOperation, loopHelper.nextAnchorResults, loopHelper.nextAnchorResultsIdxMap, loopHelper.nextAnchorResultOrignalResultMap); @@ -2777,32 +2725,34 @@ scf::ForOp ForLoopGenerator::constructNestedForOp( std::queue movedQueue; std::queue &opQueue = getFusionStrategy().getOpGroups()[groupIdx]; - - movePreOpToCurrentAnchor( - forDimIdx, groupIdx, b, loopHelper.inductionVars, loopState, - loopHelper.currentLoopStateIdxMap, nextAnchorArgsIdxMap, - nextAnchorArgs, opQueue, movedQueue, - loopHelper.originalOperandLoopArgsMap, - loopHelper.loopArgsOriginalOperandMap, loopHelper.indiceLoopMap); - - loopHelper.loopIterArgs = nextAnchorArgs; - loopHelper.currentLoopStateIdxMap = nextAnchorArgsIdxMap; - auto nxtFor = constructNestedForOp(forDimIdx + 1, groupIdx, b, loc, - dims, loopHelper); - - movePostOpToCurrentAnchor( - b, forDimIdx, groupIdx, nxtFor->getResults(), b.getBlock(), - opQueue, movedQueue, loopHelper.inductionVars, currentArgsIdxMap, - loopState, currentOriginalOperandMap, currentOperandOriginalMap, - loopHelper.nextAnchorResults, - loopHelper.nextAnchorResultOrignalResultMap, - loopHelper.indiceLoopMap); + SmallVector tmpArgs(loopState); + loopHelper.updateDataBeforePreOpMove(tmpArgs, opQueue, movedQueue); + movePreOpToCurrentAnchor(b, nextAnchorArgsIdxMap, nextAnchorArgs, + loopHelper); + loopHelper.updateDataAfterPreOpMove(nextAnchorArgsIdxMap, + nextAnchorArgs); + loopHelper.anchorIdx += 1; + auto nxtFor = + constructNestedForOp(groupIdx, b, loc, dims, loopHelper); + loopHelper.anchorIdx -= 1; + SmallVector currentArgs(loopState); + + loopHelper.updateCurrentArgsStatus(currentArgsIdxMap, currentArgs, + currentOriginalOperandMap, + currentOperandOriginalMap); + + loopHelper.updateDataBeforePostOpMove( + tmpArgs, currentArgsIdxMap, currentOriginalOperandMap, + currentOperandOriginalMap, nxtFor->getResults(), b.getBlock(), + movedQueue, loopHelper.anchorIdx); + movePostOpToCurrentAnchor(b, loopHelper); generateLoopResults( - b, loc, forDimIdx, groupIdx, loopHelper.nextAnchorResults, - loopHelper.nextAnchorResultsIdxMap, nxtFor->getResults(), - movedQueue, loopHelper.nextAnchorResultOrignalResultMap, - loopState, currentOperandOriginalMap, nextAnchorArgsIdxMap); + b, loc, loopHelper.anchorIdx, groupIdx, + loopHelper.nextAnchorResults, loopHelper.nextAnchorResultsIdxMap, + nxtFor->getResults(), movedQueue, + loopHelper.nextAnchorResultOrignalResultMap, loopState, + currentOperandOriginalMap, nextAnchorArgsIdxMap); maybeYieldValue(b, loc, loopHelper.nextAnchorResults); } @@ -2979,72 +2929,43 @@ void getOperationDataAxis(Operation *op, SmallVector &dataAxis) { }); } -bool readWriteDependency(Operation *op1, Operation *op2) { - if (not(isReadOrWriteOperation(op1) and isReadOrWriteOperation(op2))) { - return false; - } - auto readWriteOrder = [](Operation *op1, Operation *op2) { - if (isa(op1) and - isa(op2)) { - return true; - } - return false; - }; - if (!readWriteOrder(op1, op2) and !readWriteOrder(op2, op1)) { - return false; - } - - // e.g.: if op1 is read the value and pass it to op2, it is not data - // dependency - if (isOperationsHasDefUseRelation(op1, op2)) { - return false; - } - return true; -} - static inline bool 
hasSameAxis(ArrayRef dims1, ArrayRef dims2) { DenseSet checkSet(dims2.begin(), dims2.end()); - for (auto x : dims1) { - if (checkSet.contains(x)) { + for (auto x : dims1) + if (checkSet.contains(x)) return true; - } - } + return false; } /// whether two operation has data dependency /// op1 default is previous operation, op2 default is current operation bool hasDataDependency(Operation *op1, Operation *op2) { - if (!isSpecialOp(op1) and !isSpecialOp(op2)) { + if (!isSpecialOp(op1) and !isSpecialOp(op2)) return false; - } // TODO: Remove this condition to support special operation fusion in the // future - if (disableSpecialOp) { + if (disableSpecialOp) return true; - } if (isReadOrWriteOperation(op1) or isReadOrWriteOperation(op2)) { // if op1 is read the value and pass it to op2, it is not data dependency - if (isOperationsHasDefUseRelation(op1, op2)) { + if (isOperationsHasDefUseRelation(op1, op2)) return false; - } } // broadcast only fuse with post-op - if (isa(op2)) { + if (isa(op2)) return true; - } - if (isa(op1) and disableBroadcastOp) { + + if (isa(op1) and disableBroadcastOp) return true; - } // only special operation may cause data dependency - if (!isSpecialOp(op1)) { + if (!isSpecialOp(op1)) return hasDataDependency(op2, op1); - } auto res = TypeSwitch(op1) @@ -3052,9 +2973,9 @@ bool hasDataDependency(Operation *op1, Operation *op2) { SmallVector dims1, dims2; getOperationDataAxis(op1, dims1); getOperationDataAxis(op2, dims2); - if (!isSpecialOp(op2)) { + if (!isSpecialOp(op2)) return hasSameAxis(dims1, dims2); - } + return true; }) .Case( @@ -3073,9 +2994,9 @@ bool hasDataDependency(Operation *op1, Operation *op2) { break; } } - if (!reduceDependent) { + if (!reduceDependent) return false; - } + // all parallel axis should equal to op2's axis checkSet.clear(); checkSet.insert(reductionDims.begin(), reductionDims.end()); @@ -3088,34 +3009,25 @@ bool hasDataDependency(Operation *op1, Operation *op2) { checkSet.clear(); checkSet.insert(parallelDims.begin(), parallelDims.end()); auto rank = op2VectorType->getRank(); - for (auto i = 0; i < rank; i++) { - if (!checkSet.contains(i)) { + for (auto i = 0; i < rank; i++) + if (!checkSet.contains(i)) return true; - } - } + return false; } return true; }) .Case([&](vector::BroadcastOp broadcastOp) { - if (isSpecialOp(op2)) { + if (isSpecialOp(op2)) return true; - } + return !OpTrait::util::staticallyKnownBroadcastable( getOperationVectorType(op1, false)->getShape(), getOperationVectorType(op2)->getShape()); }) - .Case([&](vector::TransposeOp transposeOp) { - return true; - // SmallVector dims1, dims2; - // getOperationDataAxis(op1, dims1); - // getOperationDataAxis(op2, dims2); - // if (!isSpecialOp(op2)) { - // return hasSameAxis(dims1, dims2); - // } - // return true; - }) + .Case( + [&](vector::TransposeOp transposeOp) { return true; }) .Default([&](Operation *op) { return false; }); return res; @@ -3128,9 +3040,9 @@ Operation *getNotReadWriteOperaiton(std::queue &tmpQ) { while (!tmpQ.empty()) { Operation *cur = tmpQ.front(); tmpQ.pop(); - if (isReadOrWriteOperation(cur)) { + if (isReadOrWriteOperation(cur)) continue; - } + op = cur; } return op; @@ -3244,7 +3156,7 @@ Value setOutGroupOperationOperandResult(Operation *op, if (isa(value)) { auto valueType = mlir::dyn_cast(value); if (valueType.isSplat()) { - if (mlir::isa(valueType.getElementType())) + if (isa(valueType.getElementType())) initValueAttr = FloatAttr::get( resultElementType, valueType.getSplatValue().convertToDouble()); @@ -3314,8 +3226,6 @@ void 
ForLoopGenerator::createNewConstantOp( newConstantOp = res.value().getDefiningOp(); } else { // TODO: need to test not splat value - // newConstantOp = srcWriter.create( - // srcOp->getLoc(), srcConstantOp.getValue()); llvm::llvm_unreachable_internal( "Can't support not splat constant value."); } @@ -3456,10 +3366,9 @@ void CanonicalizerCommonUsedData::removeOpInCurrentGroups( // update removed operation related operation anchor position getFusionStrategy().getOpAnchorPos()[replacedOp] = getOperationMaxVectorType(replacedOp)->getRank() - 1; - for (Operation *x : usesOp) { + for (Operation *x : usesOp) getFusionStrategy().getOpAnchorPos()[x] = getOperationMaxVectorType(x)->getRank() - 1; - } // update operation in grpIdx group related information updateOpGroupInfo(grpIdx); @@ -3477,9 +3386,8 @@ void CanonicalizerCommonUsedData::updateOpGroupInfo(size_t grpIdx) { auto curOp = tmpOpQueue.front(); tmpOpQueue.pop(); VectorType type = getOperationMaxVectorType(curOp).value(); - if (type.getRank() > currentMaxRankType.getRank()) { + if (type.getRank() > currentMaxRankType.getRank()) getFusionStrategy().getGroupBiggestRankVectorType()[grpIdx] = type; - } } } @@ -3502,16 +3410,6 @@ void CanonicalizerCommonUsedData::updateOpOperandResultInGroups( getFusionStrategy().getOpAnchorPos()[init.getDefiningOp()] = getFusionStrategy().getOpAnchorPos()[op]; } - // directly use the read operation to do the fusion - // if (isa(op) and not disableBroadcastOp) { - // auto srcOp = op->getOperand(0).getDefiningOp(); - // getFusionStrategy().getOpAnchorPos()[srcOp] = - // getFusionStrategy().getOpAnchorPos()[op]; - - // IRRewriter rewrite(op); - // rewrite.replaceOp(op, srcOp); - // continue; - // } newOpQueue.push(op); if (result && !failed(getOperationVectorType(result.getDefiningOp()))) { @@ -3561,11 +3459,9 @@ Operation *CanonicalizerCommonUsedData::getNextTargetOperationInCurrentGroup( while (!tmpOpQueue.empty()) { auto frontOp = tmpOpQueue.front(); if (isa(frontOp)) { - for (auto x : frontOp->getOperands()) { - if (x.getDefiningOp() == curOp) { + for (auto x : frontOp->getOperands()) + if (x.getDefiningOp() == curOp) return frontOp; - } - } } tmpOpQueue.pop(); } @@ -3594,9 +3490,9 @@ void VectorOperationAnalyzer::analysisGroupMaxSteps() { llvm::SmallVector &grpSteps = getFusionStrategy().getGroupMaxSteps(); - while (idx + 1 > grpSteps.size()) { + while (idx + 1 > grpSteps.size()) grpSteps.emplace_back(steps); - } + std::queue tmpQueue(grp); auto calculateOpSteps = [&](Type type) { auto opType = dyn_cast(type); @@ -3660,9 +3556,8 @@ void VectorOperationAnalyzer::updateReturnResultKind(Operation *sourceOp, getFusionStrategy().getOpAnchorPos(); Value sourceResult = sourceOp->getResults()[0]; - if (srcOpCanoniclizedMap.contains(sourceOp)) { + if (srcOpCanoniclizedMap.contains(sourceOp)) sourceResult = srcOpCanoniclizedMap[sourceOp].second; - } size_t srcOpAnchor = groupOpResults[sourceOpGid][sourceResult].second; ReturnTypeKind prevRtKind = groupOpResults[sourceOpGid][sourceResult].first; @@ -3868,13 +3763,13 @@ void ForLoopGenerator::rectifyGroupOperands(size_t currentGroupId, mlir::FailureOr ForLoopGenerator::generateVectorizedForLoop( const size_t groupId, IRRewriter &rewriter, VectorType vectorType) { // prepare for loop iterargs - GenerateLoopHelper loopHelper; + GenerateLoopHelper loopHelper(groupId, 0); prepareForLoopArgs(groupId, loopHelper); ArrayRef shapes = vectorType.getShape(); // generate for loop - auto forOp = constructNestedForOp( - 0, groupId, rewriter, rewriter.getUnknownLoc(), shapes, 
loopHelper); + auto forOp = constructNestedForOp(groupId, rewriter, rewriter.getUnknownLoc(), + shapes, loopHelper); replaceOpUsersWithForLoopResult(forOp, groupId, loopHelper.nextAnchorResults, loopHelper.nextAnchorResultsIdxMap, loopHelper.nextAnchorResultOrignalResultMap); diff --git a/lib/gc/Transforms/TilingVector.h b/lib/gc/Transforms/TilingVector.h index eee1d3a12..3c12b30ac 100644 --- a/lib/gc/Transforms/TilingVector.h +++ b/lib/gc/Transforms/TilingVector.h @@ -79,6 +79,14 @@ struct HardWareInfo { //===----------------------------------------------------------------------===// /// Using to avoid too many parameters in function struct GenerateLoopHelper { + /// anchor id + size_t anchorIdx = 0; + /// group id + size_t groupIdx = 0; + /// for loop results + ValueRange forResults; + /// for loop block + Block *forBlock; /// loop iteration args index map DenseMap currentLoopStateIdxMap; /// loop iteration args @@ -89,18 +97,135 @@ struct GenerateLoopHelper { DenseMap nextAnchorResultsIdxMap; /// next loop anchor yield results original result map DenseMap nextAnchorResultOrignalResultMap; + /// original result with next anchor result map + DenseMap orignalResultNextAnchorResultMap; /// loop induction variables SmallVector inductionVars; /// original operand with loop args map DenseMap originalOperandLoopArgsMap; /// loop args with original operand map DenseMap loopArgsOriginalOperandMap; - /// record operation's correct loop indice, due to some operation like reduce - /// may need to reorder loop indice + /// candidate operation queue + std::queue *candidateOps; + /// moved operation queue + std::queue *movedOps; + /// record operation's correct loop indice, due to some operation like + /// reduce may need to reorder loop indice DenseMap> indiceLoopMap; GenerateLoopHelper() = default; + GenerateLoopHelper(const size_t groupIdx) noexcept { + this->groupIdx = groupIdx; + } + GenerateLoopHelper(const size_t groupIdx, const size_t anchorIdx) noexcept { + this->groupIdx = groupIdx; + this->anchorIdx = anchorIdx; + } + /// clear next anchor results related data + void clearNextAnchorResults(); + /// set next anchor results related data + void setNextAnchorResults(SmallVector ¤tAnchorResults, + DenseMap ¤tResultMap, + DenseMap ¤tResultIdxMap); + /// set next anchor iteration args + void setNextAnchorArgs(DenseMap &nextAnchorArgsIdxMap, + SmallVector &nextAnchorArgs); + /// set id of for loop anchor + void setAnchorId(const size_t anchorId) noexcept; + /// Before perform processing previous operation, we need to update some data + void updateDataBeforePreOpMove(ArrayRef loopstate, + std::queue &candidateQueue, + std::queue &movedQueue); + /// After previous operation movement, we need to update some data + void updateDataAfterPreOpMove(DenseMap &nextAnchorArgsIdxMap, + SmallVector &nextAnchorArgs); + /// Before perform processing previous operation, we need to update some data + void updateDataBeforePostOpMove( + ArrayRef iterArgs, DenseMap ¤tLoopStateIdxMap, + DenseMap ¤toriginalArgsMap, + DenseMap ¤tArgsOriginalMap, ValueRange forResults, + Block *forBlock, std::queue &movedQueue, size_t anchorId); + /// After previous operation movement, we need to update some data + void updateDataAfterPostOpMove(size_t anchorId, + DenseMap &nextAnchorArgsIdxMap, + SmallVector &nextAnchorArgs); + + void updateCurrentArgsStatus(DenseMap ¤tArgsIdxMap, + SmallVector ¤tArgs, + DenseMap &originalArgsMap, + DenseMap &argsOriginalMap); }; +void GenerateLoopHelper::setNextAnchorArgs( + DenseMap 
&nextAnchorArgsIdxMap, + SmallVector &nextAnchorArgs) { + currentLoopStateIdxMap = nextAnchorArgsIdxMap; + loopIterArgs = nextAnchorArgs; +} + +void GenerateLoopHelper::clearNextAnchorResults() { + nextAnchorResults.clear(); + nextAnchorResultsIdxMap.clear(); + nextAnchorResultOrignalResultMap.clear(); +} + +void GenerateLoopHelper::setAnchorId(size_t anchorId) noexcept { + anchorIdx = anchorId; +} + +void GenerateLoopHelper::updateDataBeforePreOpMove( + ArrayRef loopState, std::queue &candidateQueue, + std::queue &movedQueue) { + loopIterArgs = loopState; + candidateOps = &candidateQueue; + movedOps = &movedQueue; +} + +void GenerateLoopHelper::updateDataAfterPreOpMove( + DenseMap &nextAnchorArgsIdxMap, + SmallVector &nextAnchorArgs) { + setNextAnchorArgs(nextAnchorArgsIdxMap, nextAnchorArgs); +} + +void GenerateLoopHelper::updateDataBeforePostOpMove( + ArrayRef iterArgs, DenseMap ¤tLoopStateIdxMap, + DenseMap ¤toriginalArgsMap, + DenseMap ¤tArgsOriginalMap, ValueRange forResults, + Block *forBlock, std::queue &movedQueue, size_t anchorId) { + this->originalOperandLoopArgsMap = currentoriginalArgsMap; + this->loopArgsOriginalOperandMap = currentArgsOriginalMap; + this->forResults = forResults; + this->forBlock = forBlock; + this->anchorIdx = anchorId; + this->currentLoopStateIdxMap = currentLoopStateIdxMap; + this->loopIterArgs = iterArgs; + this->movedOps = &movedQueue; +} + +void GenerateLoopHelper::updateDataAfterPostOpMove( + size_t anchorId, DenseMap &nextAnchorArgsIdxMap, + SmallVector &nextAnchorArgs) { + setAnchorId(anchorId); + setNextAnchorArgs(nextAnchorArgsIdxMap, nextAnchorArgs); +} + +void GenerateLoopHelper::setNextAnchorResults( + SmallVector ¤tAnchorResults, + DenseMap ¤tResultMap, + DenseMap ¤tResultIdxMap) { + nextAnchorResults = std::move(currentAnchorResults); + nextAnchorResultOrignalResultMap = std::move(currentResultMap); + nextAnchorResultsIdxMap = std::move(currentResultIdxMap); +} + +void GenerateLoopHelper::updateCurrentArgsStatus( + DenseMap ¤tArgsIdxMap, SmallVector ¤tArgs, + DenseMap &originalArgsMap, + DenseMap &argsOriginalMap) { + setNextAnchorArgs(currentArgsIdxMap, currentArgs); + originalOperandLoopArgsMap = originalArgsMap; + loopArgsOriginalOperandMap = argsOriginalMap; +} + /// Vector type conversion helper class class TypeHelper { private: @@ -547,9 +672,8 @@ class ForLoopGenerator : virtual public CanonicalizerCommonUsedData { generateVectorizedForLoop(const size_t groupId, IRRewriter &rewriter, const VectorType vectorType); - scf::ForOp constructNestedForOp(const size_t forDimIdx, const size_t groupIdx, - OpBuilder &b, const Location &loc, - ArrayRef dims, + scf::ForOp constructNestedForOp(const size_t groupIdx, OpBuilder &b, + const Location &loc, ArrayRef dims, GenerateLoopHelper &loopGenerator); void moveOperationsToCurrentForBody( @@ -599,30 +723,13 @@ class ForLoopGenerator : virtual public CanonicalizerCommonUsedData { DenseMap &nextOperandIdxMap); /// todo: need to add a struct to remove so many parameters - void movePostOpToCurrentAnchor( - OpBuilder &b, const int anchorIdx, const int groupIdx, - const ValueRange &forResults, const Block *forBlock, - std::queue &candidateOps, - std::queue &movedOperation, ArrayRef inductionVars, - const llvm::DenseMap &operandIdxMap, - const ValueRange &loopState, - DenseMap &originalOperandLoopArgsMap, - DenseMap &loopArgsOriginalOperandMap, - const llvm::SmallVector &nextAnchorResults, - DenseMap &forResultOrignalResultMap, - DenseMap> &indiceLoopMap); + void movePostOpToCurrentAnchor(OpBuilder &b, + 
GenerateLoopHelper &loopHelperParam); - void movePreOpToCurrentAnchor( - const size_t anchorIdx, const size_t groupIdx, OpBuilder &b, - ArrayRef inductionVars, const ValueRange &loopState, - llvm::DenseMap ¤tLoopStateIdxMap, - llvm::DenseMap &nextLoopStateIdxMap, - llvm::SmallVector &nextAnchorArgs, - std::queue &candidateQueue, - std::queue &movedQueue, - DenseMap &originalOperandLoopArgsMap, - DenseMap &LoopArgsoriginalOperandMap, - DenseMap> &indiceLoopMap); + void movePreOpToCurrentAnchor(OpBuilder &b, + DenseMap &nextLoopStateIdxMap, + SmallVector &nextAnchorArgs, + GenerateLoopHelper &loopHelperParam); void replaceOperationsWithForLoopResult( IRRewriter &rewrite, const ValueRange &forResults, const Block *forBlock, @@ -636,21 +743,13 @@ class ForLoopGenerator : virtual public CanonicalizerCommonUsedData { void rearrageMultiReductionIR( const size_t grpIdx, DenseMap> &indiceLoopMap); - scf::ForOp reductionAxisGenerateForLoop( - OpBuilder &opBuilder, const int groupIdx, const size_t reductionIdx, - const int anchorIdx, llvm::DenseMap ¤tLoopStateIdxMap, - const ValueRange &initArgs, - DenseMap &originalOperandLoopArgsMap, - DenseMap &loopArgsOriginalOperandMap, - llvm::SmallVector &nextAnchorResults, - llvm::DenseMap &nextAnchorResultsIdxMap, - llvm::SmallVector &inductionVars, - DenseMap &forResultOrignalResultMap, - DenseMap &originalResultForResultMap, - DenseMap> &indiceLoopMap); + + scf::ForOp reductionAxisGenerateForLoop(OpBuilder &opBuilder, + const size_t reductionIdx, + GenerateLoopHelper &loopHelperParam); scf::ForOp parallelAxisGenerateForLoop( - OpBuilder &opBuilder, const int groupIdx, const size_t parallelIdx, + OpBuilder &opBuilder, DenseMap> &indiceLoopMap, GenerateLoopHelper &loopHelperParam); From db59850f4edd6550be59530b8936206786efc94b Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Sat, 7 Sep 2024 22:06:48 +0800 Subject: [PATCH 40/66] simplify function parameters --- lib/gc/Transforms/CPUPhysicalRegisterPass.cpp | 203 +++++++----------- lib/gc/Transforms/TilingVector.h | 88 +++----- 2 files changed, 107 insertions(+), 184 deletions(-) diff --git a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp index 3130760e5..8208cf877 100644 --- a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp +++ b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp @@ -959,11 +959,8 @@ Value makeIndexArithConstantOp(OpBuilder &opBuilder, const Location &loc, } void ForLoopGenerator::moveOperationsToCurrentForBody( - const size_t groupIdx, const OpBuilder &b, ArrayRef inductionVars, - const DenseMap &operandIdxMap, ValueRange loopState, - DenseMap &originalOperandLoopArgsMap, - std::queue &opQueue, - DenseMap> &indiceLoopMap) { + const OpBuilder &b, std::queue &opQueue, + GenerateLoopHelper &loopHelperParam) { auto &opPermuationMap = getOpPermuationMap(); auto tmpQ(opQueue); while (!tmpQ.empty()) { @@ -971,9 +968,11 @@ void ForLoopGenerator::moveOperationsToCurrentForBody( tmpQ.pop(); x->moveBefore(b.getBlock(), b.getBlock()->end()); // check operation type to set correct operand - setOperationCorrectOperand(x, loopState, operandIdxMap, - originalOperandLoopArgsMap, inductionVars, - opPermuationMap, indiceLoopMap); + setOperationCorrectOperand(x, loopHelperParam.loopIterArgs, + loopHelperParam.currentLoopStateIdxMap, + loopHelperParam.originalOperandLoopArgsMap, + loopHelperParam.inductionVars, opPermuationMap, + loopHelperParam.indiceLoopMap); } } @@ -1024,20 +1023,16 @@ void updateCurrentArgsStatus(ValueRange loopState, const size_t loopStateIdx, } 
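The "simplify function parameters" change above threads a single GenerateLoopHelper (group index, anchor index, induction variables, operand/result maps) through the loop builders by reference instead of passing a dozen positional arguments. Below is a minimal, self-contained C++ sketch of that parameter-object pattern, including the `anchorIdx += 1; ...; anchorIdx -= 1;` bookkeeping used around the recursive call; all names in the sketch are illustrative placeholders, not APIs of this pass.

// Sketch only: illustrates the parameter-object pattern behind the
// GenerateLoopHelper refactor. Struct and function names are hypothetical.
#include <cstddef>
#include <cstdio>
#include <vector>

// State that would otherwise travel as many positional parameters.
struct LoopBuildState {
  std::size_t groupIdx = 0;       // which fused operation group is lowered
  std::size_t anchorIdx = 0;      // current loop-nest depth (anchor position)
  std::vector<int> inductionVars; // one placeholder entry per generated loop
};

// Innermost body: reads only the state it needs.
static void emitInnermostBody(const LoopBuildState &state) {
  std::printf("group %zu: innermost body at depth %zu with %zu loop IVs\n",
              state.groupIdx, state.anchorIdx, state.inductionVars.size());
}

// Recursive builder with a single state parameter. The caller's view of
// anchorIdx is restored before returning, mirroring the increment/decrement
// pattern around the recursive constructNestedForOp call in the patch.
static void buildNestedLoops(LoopBuildState &state,
                             const std::vector<int> &tripCounts) {
  if (state.anchorIdx == tripCounts.size() - 1) {
    emitInnermostBody(state);
    return;
  }
  state.inductionVars.push_back(0); // stand-in for the new loop's IV
  state.anchorIdx += 1;
  buildNestedLoops(state, tripCounts);
  state.anchorIdx -= 1;
  state.inductionVars.pop_back();
}

int main() {
  LoopBuildState state;
  state.groupIdx = 3;
  buildNestedLoops(state, {8, 16, 32}); // hypothetical loop trip counts
  return 0;
}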
void ForLoopGenerator::getInitArgsToNextAnchor( - const size_t anchorIdx, const size_t groupId, - const std::queue &nextOperations, const ValueRange &loopState, - DenseMap ¤tLoopStateIdxMap, DenseMap &nextAnchorArgsIdxMap, SmallVector &nextAnchorArgs, - DenseMap &originalOperandLoopArgsMap, - DenseMap &loopArgsOriginalOperandMap) { + GenerateLoopHelper &loopHelperParam) { DenseMap &opAnchorPos = getFusionStrategy().getOpAnchorPos(); - SetVector &opInitArgs = getGroupOpInitArgs()[groupId]; + SetVector &opInitArgs = getGroupOpInitArgs()[loopHelperParam.groupIdx]; DenseSet visited; // find the next anchor arguments - std::queue tmpQ(nextOperations); + std::queue tmpQ(*loopHelperParam.candidateOps); DenseMap nextOriginalOperandMap, nextOperandOriginalMap; while (!tmpQ.empty()) { @@ -1046,19 +1041,21 @@ void ForLoopGenerator::getInitArgsToNextAnchor( auto curOperands = cur->getOperands(); for (auto x : curOperands) { if (!visited.contains(x) and opInitArgs.contains(x) and - opAnchorPos[cur] > anchorIdx) { - assert(originalOperandLoopArgsMap.contains(x)); - int loopStateIdx = - currentLoopStateIdxMap[originalOperandLoopArgsMap[x]]; - updateCurrentArgsStatus(loopState, loopStateIdx, nextAnchorArgs, x, - nextAnchorArgsIdxMap, nextOriginalOperandMap, - nextOperandOriginalMap); + opAnchorPos[cur] > loopHelperParam.anchorIdx) { + assert(loopHelperParam.originalOperandLoopArgsMap.contains(x)); + int loopStateIdx = loopHelperParam.currentLoopStateIdxMap + [loopHelperParam.originalOperandLoopArgsMap[x]]; + updateCurrentArgsStatus(loopHelperParam.loopIterArgs, loopStateIdx, + nextAnchorArgs, x, nextAnchorArgsIdxMap, + nextOriginalOperandMap, nextOperandOriginalMap); visited.insert(x); } } } - originalOperandLoopArgsMap = std::move(nextOriginalOperandMap); - loopArgsOriginalOperandMap = std::move(nextOperandOriginalMap); + loopHelperParam.originalOperandLoopArgsMap = + std::move(nextOriginalOperandMap); + loopHelperParam.loopArgsOriginalOperandMap = + std::move(nextOperandOriginalMap); } void ForLoopGenerator::getOperationInCurrentAnchor( @@ -1076,10 +1073,8 @@ void ForLoopGenerator::getOperationInCurrentAnchor( } void ForLoopGenerator::replaceOperationsWithForLoopResult( - IRRewriter &rewrite, const ValueRange &forResults, const Block *forBlock, - const llvm::SmallVector &nextAnchorResults, - const std::queue &movingOperations, - DenseMap &forResultOrignalResultMap) { + IRRewriter &rewrite, const std::queue &movingOperations, + GenerateLoopHelper &loopHelperParam) { auto tmpQ(movingOperations); DenseSet operationOperands; while (!tmpQ.empty()) { @@ -1092,19 +1087,17 @@ void ForLoopGenerator::replaceOperationsWithForLoopResult( return operationOperands.contains(use.get()); }; for (auto [nxtForResult, nextLoopResult] : - zip(forResults, nextAnchorResults)) { - Value originalResult = forResultOrignalResultMap[nextLoopResult]; + zip(loopHelperParam.forResults, loopHelperParam.nextAnchorResults)) { + Value originalResult = + loopHelperParam.nextAnchorResultOrignalResultMap[nextLoopResult]; rewrite.replaceOpUsesWithIf(originalResult.getDefiningOp(), nxtForResult, replaceIfFn); } } -/// \param [out] nextAnchorArgsIdxMap -/// \param [out] nextAnchorArgs -/// \param [out] movingQueue -/// \param [in, out] originalOperandLoopArgsMap -/// \param [in, out] LoopArgsoriginalOperandMap +/// \param [in,out] nextLoopStateIdxMap +/// \param [in,out] nextAnchorArgs void ForLoopGenerator::movePreOpToCurrentAnchor( OpBuilder &b, DenseMap &nextLoopStateIdxMap, SmallVector &nextAnchorArgs, @@ -1119,19 +1112,10 @@ void 
ForLoopGenerator::movePreOpToCurrentAnchor( rewriteOperationAsVectorize(b, loopHelperParam.groupIdx, &movingOperation); // 3. move opeartions to current for block - moveOperationsToCurrentForBody( - loopHelperParam.groupIdx, b, loopHelperParam.inductionVars, - loopHelperParam.currentLoopStateIdxMap, loopHelperParam.loopIterArgs, - loopHelperParam.originalOperandLoopArgsMap, movingOperation, - loopHelperParam.indiceLoopMap); + moveOperationsToCurrentForBody(b, movingOperation, loopHelperParam); // 4. get next anchor args - getInitArgsToNextAnchor( - loopHelperParam.anchorIdx, loopHelperParam.groupIdx, - *loopHelperParam.candidateOps, loopHelperParam.loopIterArgs, - loopHelperParam.currentLoopStateIdxMap, nextLoopStateIdxMap, - nextAnchorArgs, loopHelperParam.originalOperandLoopArgsMap, - loopHelperParam.loopArgsOriginalOperandMap); + getInitArgsToNextAnchor(nextLoopStateIdxMap, nextAnchorArgs, loopHelperParam); // 5. move operations to moved queue while (!movingOperation.empty()) { @@ -1143,26 +1127,22 @@ void ForLoopGenerator::movePreOpToCurrentAnchor( void ForLoopGenerator::movePostOpToCurrentAnchor( OpBuilder &b, GenerateLoopHelper &loopHelperParam) { - // 1. move post-op to current loop body std::queue movingOperations; + // 1. get post-op to current loop bod getOperationInCurrentAnchor(loopHelperParam.anchorIdx, *loopHelperParam.candidateOps, movingOperations); + // 2. rewrite operation as vectorize IR rewriteOperationAsVectorize(b, loopHelperParam.groupIdx, &movingOperations); - moveOperationsToCurrentForBody( - loopHelperParam.anchorIdx, b, loopHelperParam.inductionVars, - loopHelperParam.currentLoopStateIdxMap, loopHelperParam.loopIterArgs, - loopHelperParam.originalOperandLoopArgsMap, movingOperations, - loopHelperParam.indiceLoopMap); + // 3. move opeartions to current for block + moveOperationsToCurrentForBody(b, movingOperations, loopHelperParam); - // 2. replace correct for loop result to post-op + // 4. replace correct for loop result to post-op IRRewriter rewriter(b); - replaceOperationsWithForLoopResult( - rewriter, loopHelperParam.forResults, loopHelperParam.forBlock, - loopHelperParam.nextAnchorResults, movingOperations, - loopHelperParam.nextAnchorResultOrignalResultMap); + replaceOperationsWithForLoopResult(rewriter, movingOperations, + loopHelperParam); - // 3. move operations to moved queue + // 5. 
move operations to moved queue while (!movingOperations.empty()) { loopHelperParam.movedOps->push(movingOperations.front()); movingOperations.pop(); @@ -1170,47 +1150,46 @@ void ForLoopGenerator::movePostOpToCurrentAnchor( } void ForLoopGenerator::generateLoopResults( - OpBuilder &b, const Location &loc, const size_t anchorIdx, - const size_t groupIdx, SmallVector &nextAnchorResults, - DenseMap &nextAnchorResultsIdxMap, const ValueRange &forResults, - const std::queue &movedOperation, - DenseMap &forResultOrignalResultMap, ValueRange loopState, - DenseMap ¤tOperandOriginMap, + OpBuilder &b, const Location &loc, GenerateLoopHelper &loopHelperParam, DenseMap &nextOperandIdxMap) { SmallVector results; DenseMap currentResultMap; - llvm::outs() << " move current operation to current for block\n"; - printQueue(movedOperation); - getResultInCurrentOps(anchorIdx, groupIdx, movedOperation, results, - nextAnchorResultsIdxMap, currentResultMap); + getResultInCurrentOps(loopHelperParam.anchorIdx, loopHelperParam.groupIdx, + *loopHelperParam.movedOps, results, + loopHelperParam.nextAnchorResultsIdxMap, + currentResultMap); llvm::MapVector> &groupResults = - getGroupOpResults()[groupIdx]; + getGroupOpResults()[loopHelperParam.groupIdx]; // check for yield results whether need to return to next anchor - for (auto [idx, forResult] : llvm::enumerate(nextAnchorResults)) { - Value originalResult = forResultOrignalResultMap[forResult]; + for (auto [idx, forResult] : + llvm::enumerate(loopHelperParam.nextAnchorResults)) { + Value originalResult = + loopHelperParam.nextAnchorResultOrignalResultMap[forResult]; if (groupResults.contains(originalResult)) { std::pair resultType = groupResults[originalResult]; - if (needReturnResult(resultType, anchorIdx)) { - results.emplace_back(forResults[idx]); - currentResultMap[forResults[idx]] = originalResult; + if (needReturnResult(resultType, loopHelperParam.anchorIdx)) { + results.emplace_back(loopHelperParam.forResults[idx]); + currentResultMap[loopHelperParam.forResults[idx]] = originalResult; } } } - nextAnchorResults.clear(); - nextAnchorResultsIdxMap.clear(); + loopHelperParam.nextAnchorResults.clear(); + loopHelperParam.nextAnchorResultsIdxMap.clear(); // reduction operation due to special process results size will be zero if (results.size() > 0) - for (Value x : loopState) { - nextAnchorResults.emplace_back(results[nextOperandIdxMap[x]]); - nextAnchorResultsIdxMap[results[nextOperandIdxMap[x]]] = - nextAnchorResults.size() - 1; + for (Value x : loopHelperParam.loopIterArgs) { + loopHelperParam.nextAnchorResults.emplace_back( + results[nextOperandIdxMap[x]]); + loopHelperParam.nextAnchorResultsIdxMap[results[nextOperandIdxMap[x]]] = + loopHelperParam.nextAnchorResults.size() - 1; } - forResultOrignalResultMap = std::move(currentResultMap); + loopHelperParam.nextAnchorResultOrignalResultMap = + std::move(currentResultMap); } scf::ForOp ForLoopGenerator::reductionAxisGenerateForLoop( @@ -1298,13 +1277,7 @@ scf::ForOp ForLoopGenerator::reductionAxisGenerateForLoop( movePostOpToCurrentAnchor(b, loopHelperParam); // 4. 
generate loop results - generateLoopResults(b, loc, currentAnchorId, loopHelperParam.groupIdx, - loopHelperParam.nextAnchorResults, - loopHelperParam.nextAnchorResultsIdxMap, - nxtFor->getResults(), *loopHelperParam.movedOps, - loopHelperParam.nextAnchorResultOrignalResultMap, - loopState, currentArgsOriginalMap, - nextAnchorArgsIdxMap); + generateLoopResults(b, loc, loopHelperParam, nextAnchorArgsIdxMap); // reduction must return accumulate if (loopHelperParam.orignalResultNextAnchorResultMap.contains( @@ -1353,11 +1326,7 @@ scf::ForOp ForLoopGenerator::reductionAxisGenerateForLoop( rewriteOperationAsVectorize(b, loopHelperParam.groupIdx, &movingOperation); - moveOperationsToCurrentForBody( - loopHelperParam.groupIdx, b, loopHelperParam.inductionVars, - loopHelperParam.currentLoopStateIdxMap, loopState, - loopHelperParam.originalOperandLoopArgsMap, movingOperation, - loopHelperParam.indiceLoopMap); + moveOperationsToCurrentForBody(b, movingOperation, loopHelperParam); loopHelperParam.movedOps = &movingOperation; loopHelperParam.candidateOps = &opQueue; @@ -1510,32 +1479,27 @@ scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( } scf::ForOp nxtFor; - DenseMap originalResultForResultMap; // 2. generate next for loop - if (rdCanonicalizer.hasLastDimReduction() or - loopHelperParam.anchorIdx < parallelAxis.size() - 1) { - loopHelperParam.anchorIdx += 1; + bool useParallelLoop = + rdCanonicalizer.hasLastDimReduction() or + loopHelperParam.anchorIdx < parallelAxis.size() - 1; + loopHelperParam.anchorIdx += 1; + if (useParallelLoop) { nxtFor = parallelAxisGenerateForLoop(b, indiceLoopMap, loopHelperParam); - } else if (parallelAxis.size() - 1 == loopHelperParam.anchorIdx) { - loopHelperParam.anchorIdx += 1; + } else { nxtFor = reductionAxisGenerateForLoop(b, 0, loopHelperParam); } loopHelperParam.anchorIdx -= 1; + // 3. move postOp to current body loopHelperParam.updateDataBeforePostOpMove( tmpArgs, currentLoopStateIdxMap, currentOriginalOperandMap, currentOperandOriginalMap, nxtFor->getResults(), nxtFor->getBlock(), movedQueue, loopHelperParam.anchorIdx); - // 3. move postOp to current body movePostOpToCurrentAnchor(b, loopHelperParam); // 4. generate loop results - generateLoopResults( - b, loc, loopHelperParam.anchorIdx, loopHelperParam.groupIdx, - loopHelperParam.nextAnchorResults, - loopHelperParam.nextAnchorResultsIdxMap, nxtFor->getResults(), - movedQueue, loopHelperParam.nextAnchorResultOrignalResultMap, - loopState, currentOperandOriginalMap, nextAnchorArgsIdxMap); + generateLoopResults(b, loc, loopHelperParam, nextAnchorArgsIdxMap); maybeYieldValue(b, loc, loopHelperParam.nextAnchorResults); } else if (loopHelperParam.anchorIdx == parallelAxis.size()) { @@ -2636,11 +2600,11 @@ void ForLoopGenerator::setOperationCorrectOperand( "Permuatation map must contains dim expr."); size_t dim; - if (auto d = dyn_cast(x)) + if (auto d = dyn_cast(x)) { dim = d.getPosition(); - - if (auto d = dyn_cast(x)) + } else if (auto d = dyn_cast(x)) { dim = d.getValue(); + } ShapedType tensorType = cast(op->getOperandTypes()[offset - 1]); @@ -2698,11 +2662,7 @@ scf::ForOp ForLoopGenerator::constructNestedForOp( rewriteOperationAsVectorize(b, groupIdx, &movingOperation); // 3. 
move opeartions to current for block - moveOperationsToCurrentForBody( - groupIdx, b, loopHelper.inductionVars, - loopHelper.currentLoopStateIdxMap, loopState, - loopHelper.originalOperandLoopArgsMap, movingOperation, - loopHelper.indiceLoopMap); + moveOperationsToCurrentForBody(b, movingOperation, loopHelper); getResultInCurrentOps(loopHelper.anchorIdx, groupIdx, movingOperation, loopHelper.nextAnchorResults, @@ -2747,12 +2707,7 @@ scf::ForOp ForLoopGenerator::constructNestedForOp( movedQueue, loopHelper.anchorIdx); movePostOpToCurrentAnchor(b, loopHelper); - generateLoopResults( - b, loc, loopHelper.anchorIdx, groupIdx, - loopHelper.nextAnchorResults, loopHelper.nextAnchorResultsIdxMap, - nxtFor->getResults(), movedQueue, - loopHelper.nextAnchorResultOrignalResultMap, loopState, - currentOperandOriginalMap, nextAnchorArgsIdxMap); + generateLoopResults(b, loc, loopHelper, nextAnchorArgsIdxMap); maybeYieldValue(b, loc, loopHelper.nextAnchorResults); } @@ -2878,9 +2833,8 @@ void getSrcBroadcastDim(const ShapedType &input, const ShapedType &output, bcAxis.emplace_back(i); } } - if (bcAxis.empty()) { + if (bcAxis.empty()) bcAxis.emplace_back(-1); - } } void getOperationDataAxis(Operation *op, SmallVector &dataAxis) { @@ -2889,6 +2843,7 @@ void getOperationDataAxis(Operation *op, SmallVector &dataAxis) { [&](vector::MultiDimReductionOp multiReductionOp) { auto rdDimsRange = multiReductionOp.getReductionDims(); dataAxis.assign(rdDimsRange.begin(), rdDimsRange.end()); + return; }) .Case([&](vector::ShapeCastOp shapeCastOp) { auto srcType = shapeCastOp.getSourceVectorType(); @@ -2900,6 +2855,7 @@ void getOperationDataAxis(Operation *op, SmallVector &dataAxis) { } else { shapeCastSourceAxis(dstShape, srcShape, dataAxis); } + return; }) .Case([&](vector::BroadcastOp broadcastOp) { auto srcType = broadcastOp.getSourceType(); @@ -2911,6 +2867,7 @@ void getOperationDataAxis(Operation *op, SmallVector &dataAxis) { auto outputType = mlir::cast(dstType); getSrcBroadcastDim(inputType, outputType, dataAxis); } + return; }) .Case([&](vector::TransposeOp transposeOp) { auto perm = transposeOp.getPermutation(); @@ -2921,11 +2878,13 @@ void getOperationDataAxis(Operation *op, SmallVector &dataAxis) { } start++; } + return; }) .Default([&](Operation *op) { // default is last axis dataAxis.emplace_back( cast(op->getResultTypes()[0]).getRank() - 1); + return; }); } diff --git a/lib/gc/Transforms/TilingVector.h b/lib/gc/Transforms/TilingVector.h index 3c12b30ac..e2791010e 100644 --- a/lib/gc/Transforms/TilingVector.h +++ b/lib/gc/Transforms/TilingVector.h @@ -25,30 +25,13 @@ #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" #include "mlir/Dialect/Vector/Transforms/Passes.h" #include "mlir/ExecutionEngine/Float16bits.h" -#include "mlir/IR/AffineExpr.h" -#include "mlir/IR/AffineMap.h" -#include "mlir/IR/BuiltinTypeInterfaces.h" -#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Visitors.h" -#include "mlir/Pass/Pass.h" #include "mlir/Transforms/CSE.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" -#include "llvm/ADT/APFloat.h" -#include "llvm/ADT/DenseSet.h" -#include "llvm/ADT/MapVector.h" -#include "llvm/ADT/SetVector.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/Support/Casting.h" -#include "llvm/Support/ErrorHandling.h" -#include #include -#include -#include -#include -#include // #include "gc/Dialect/Microkernel/MicrokernelOps.h" namespace mlir { namespace 
gc { @@ -60,23 +43,16 @@ namespace { /// build a constant operation of index type Value makeIndexArithConstantOp(OpBuilder &opBuilder, Location &loc, int64_t x); -/// set correct operand for the operation -void setOperationCorrectOperand( - Operation *op, ValueRange iterArgs, DenseMap &operandIdxMap, - DenseMap &originalOperandLoopArgsMap, - ArrayRef inductionVars, - DenseMap &opPermuationMap); + /// get operation read or write tensor mlir::FailureOr getOperationOperateTensor(Operation *op); +/// record hardware information struct HardWareInfo { bool favx512f = true; bool favx2 = true; }; -//===----------------------------------------------------------------------===// -// helper function -//===----------------------------------------------------------------------===// /// Using to avoid too many parameters in function struct GenerateLoopHelper { /// anchor id @@ -226,6 +202,13 @@ void GenerateLoopHelper::updateCurrentArgsStatus( loopArgsOriginalOperandMap = argsOriginalMap; } +/// set correct operand for the operation +void setOperationCorrectOperand( + Operation *op, ValueRange iterArgs, DenseMap &operandIdxMap, + DenseMap &originalOperandLoopArgsMap, + ArrayRef inductionVars, + DenseMap &opPermuationMap); + /// Vector type conversion helper class class TypeHelper { private: @@ -521,11 +504,10 @@ class CanonicalizerCommonUsedData : public TypeHelper { CanonicalizerCommonUsedData( VectorFusionStrategy &fusionStrategy, - llvm::SmallVector< - llvm::MapVector>, 8> + SmallVector>, 8> &groupOpResults, - llvm::SmallVector, 8> &groupOpInitArgs, - llvm::DenseMap &opPermuationMap) + SmallVector, 8> &groupOpInitArgs, + DenseMap &opPermuationMap) : fusionStrategy(fusionStrategy), groupOpResults(groupOpResults), groupOpInitArgs(groupOpInitArgs), opPermuationMap(opPermuationMap) {} virtual ~CanonicalizerCommonUsedData() noexcept {}; @@ -585,17 +567,17 @@ class CanonicalizerCommonUsedData : public TypeHelper { return multiRdCanonicalizers; } - llvm::SmallVector & + SmallVector & getBroadcastCanonicalizers() noexcept { return broadcastCanonicalizers; } - llvm::SmallVector & + SmallVector & getTransposeCanonicalizers() noexcept { return transposeCanonicalizers; } - llvm::SmallVector & + SmallVector & getShapeCastCanonicalizers() noexcept { return shapeCastCanonicalizers; } @@ -676,12 +658,9 @@ class ForLoopGenerator : virtual public CanonicalizerCommonUsedData { const Location &loc, ArrayRef dims, GenerateLoopHelper &loopGenerator); - void moveOperationsToCurrentForBody( - const size_t groupIdx, const OpBuilder &b, ArrayRef inductionVars, - const llvm::DenseMap &operandIdxMap, ValueRange loopState, - DenseMap &originalOperandLoopArgsMap, - std::queue &queue, - DenseMap> &indiceLoopMap); + void moveOperationsToCurrentForBody(const OpBuilder &b, + std::queue &queue, + GenerateLoopHelper &loopHelperParam); void setOperationCorrectOperand( Operation *op, ValueRange iterArgs, @@ -696,30 +675,17 @@ class ForLoopGenerator : virtual public CanonicalizerCommonUsedData { SmallVector &results, DenseMap &nextAnchorResultsIdxMap, DenseMap &forResultOrignalResultMap); - - /// todo: need to add a struct to remove so many parameters - void - getInitArgsToNextAnchor(const size_t anchorIdx, const size_t groupId, - const std::queue &nextOperations, - const ValueRange &loopState, - llvm::DenseMap ¤tLoopStateIdxMap, - llvm::DenseMap &nextAnchorArgsIdxMap, - llvm::SmallVector &nextAnchorArgs, - DenseMap &originalOperandLoopArgsMap, - DenseMap &loopArgsOriginalOperandMap); + /// get next anchor's iteration loop args + void 
getInitArgsToNextAnchor(llvm::DenseMap &nextAnchorArgsIdxMap, + llvm::SmallVector &nextAnchorArgs, + GenerateLoopHelper &loopHelperParam); void getOperationInCurrentAnchor(const size_t anchorIdx, std::queue &fromQueue, std::queue &toQueue); + /// get current loop operation result void generateLoopResults(OpBuilder &b, const Location &loc, - const size_t anchorIdx, const size_t groupIdx, - llvm::SmallVector &nextAnchorResults, - llvm::DenseMap &nextAnchorResultsIdxMap, - const ValueRange &forResults, - const std::queue &movedOperaiton, - DenseMap &forResultOrignalResultMap, - ValueRange loopState, - DenseMap ¤tOperandOriginMap, + GenerateLoopHelper &loopHelperParam, DenseMap &nextOperandIdxMap); /// todo: need to add a struct to remove so many parameters @@ -732,10 +698,8 @@ class ForLoopGenerator : virtual public CanonicalizerCommonUsedData { GenerateLoopHelper &loopHelperParam); void replaceOperationsWithForLoopResult( - IRRewriter &rewrite, const ValueRange &forResults, const Block *forBlock, - const llvm::SmallVector &nextAnchorResults, - const std::queue &movingOperations, - DenseMap &forResultOrignalResultMap); + IRRewriter &rewrite, const std::queue &movingOperations, + GenerateLoopHelper &loopHelperParam); // multireduction forloop methods scf::ForOp generateMultiReductionForLoop(const size_t grpIdx); /// Rearrange the current opIR to facilitate the generation of the correct From 03531010bb5ffe6837b55d6b65ff07be7629c9a6 Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Mon, 9 Sep 2024 14:40:03 +0800 Subject: [PATCH 41/66] update --- lib/gc/Transforms/CPUPhysicalRegisterPass.cpp | 370 +++++++++--------- lib/gc/Transforms/TilingVector.h | 83 ++-- .../gc/Transforms/cpu-vetor-distribution.mlir | 60 +-- 3 files changed, 256 insertions(+), 257 deletions(-) diff --git a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp index 8208cf877..9c770fc6e 100644 --- a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp +++ b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp @@ -152,9 +152,9 @@ bool isNotSupportOperation(Operation *op) { /// default will return the opeation result type mlir::FailureOr getOperationVectorType(Operation *op, bool isPrevOp = true) { - if (!op) { + if (not op) return failure(); - } + auto isDynamicType = [](VectorType &type) { return !type.hasStaticShape(); }; auto ret = TypeSwitch>(op) @@ -268,9 +268,11 @@ bool isLastDim(const AffineExpr &expr, const size_t rank) { } int TypeHelper::generateValidSteps(int steps, VectorType type) { - return type.getShape().back() >= steps - ? steps - : getNearestVectorStep(type.getShape().back()); + if (type.getShape().back() >= steps) + return steps; + int evenStep = getNearestVectorStep(type.getShape().back()); + auto typebits = type.getElementTypeBitWidth(); + return evenStep * typebits >= 128 ? 
evenStep : 1; } // Get the maximum number of current data types that a register can hold @@ -1192,6 +1194,17 @@ void ForLoopGenerator::generateLoopResults( std::move(currentResultMap); } +void updateLoopArgsData(Value val, Value originalVal, + SmallVector &argsArray, + DenseMap &anchorArgsIdxMap, + DenseMap &originalOperandLoopArgsMap, + DenseMap &loopArgsOriginalOperandMap) { + argsArray.emplace_back(val); + anchorArgsIdxMap[val] = argsArray.size() - 1; + loopArgsOriginalOperandMap[val] = originalVal; + originalOperandLoopArgsMap[originalVal] = val; +} + scf::ForOp ForLoopGenerator::reductionAxisGenerateForLoop( OpBuilder &opBuilder, const size_t reductionIdx, GenerateLoopHelper &loopHelperParam) { @@ -1227,6 +1240,7 @@ scf::ForOp ForLoopGenerator::reductionAxisGenerateForLoop( loopHelperParam.inductionVars.emplace_back(iv); size_t currentAnchorId = loopHelperParam.anchorIdx; SmallVector tmpArgs(loopState); + Value originalRetVal = multireductionOp->getResults()[0]; if (reductionIdx < reductionAxis.size() - 1) { @@ -1281,23 +1295,21 @@ scf::ForOp ForLoopGenerator::reductionAxisGenerateForLoop( // reduction must return accumulate if (loopHelperParam.orignalResultNextAnchorResultMap.contains( - multireductionOp->getResults()[0])) { + originalRetVal)) { Value lastForResult = - loopHelperParam.orignalResultNextAnchorResultMap - [multireductionOp->getResults()[0]]; + loopHelperParam + .orignalResultNextAnchorResultMap[originalRetVal]; size_t retIdx = nextAnchorArgsIdxMap [loopHelperParam .nextAnchorResultOrignalResultMap[lastForResult]]; Value forRes = nxtFor->getResults()[retIdx]; // accumulate for loop iter args must be last, so we just put the // reduction result as the last result - loopHelperParam.nextAnchorResults.emplace_back(forRes); - loopHelperParam.nextAnchorResultsIdxMap[forRes] = - loopHelperParam.nextAnchorResults.size() - 1; - loopHelperParam.nextAnchorResultOrignalResultMap[forRes] = - multireductionOp->getResults()[0]; - loopHelperParam.orignalResultNextAnchorResultMap - [multireductionOp->getResults()[0]] = forRes; + updateLoopArgsData( + forRes, originalRetVal, loopHelperParam.nextAnchorResults, + loopHelperParam.nextAnchorResultsIdxMap, + loopHelperParam.orignalResultNextAnchorResultMap, + loopHelperParam.nextAnchorResultOrignalResultMap); } maybeYieldValue(b, loc, loopHelperParam.nextAnchorResults); @@ -1348,12 +1360,11 @@ scf::ForOp ForLoopGenerator::reductionAxisGenerateForLoop( movePostOpToCurrentAnchor(b, loopHelperParam); loopHelperParam.nextAnchorResults.clear(); - loopHelperParam.nextAnchorResults.emplace_back(reductionResult); - loopHelperParam.nextAnchorResultsIdxMap[reductionResult] = 0; - loopHelperParam.nextAnchorResultOrignalResultMap[reductionResult] = - multireductionOp->getResults()[0]; - loopHelperParam.orignalResultNextAnchorResultMap - [multireductionOp->getResults()[0]] = reductionResult; + updateLoopArgsData(reductionResult, originalRetVal, + loopHelperParam.nextAnchorResults, + loopHelperParam.nextAnchorResultsIdxMap, + loopHelperParam.orignalResultNextAnchorResultMap, + loopHelperParam.nextAnchorResultOrignalResultMap); getResultInCurrentOps( loopHelperParam.anchorIdx, loopHelperParam.groupIdx, movingOperation, loopHelperParam.nextAnchorResults, @@ -1366,12 +1377,56 @@ scf::ForOp ForLoopGenerator::reductionAxisGenerateForLoop( return forOp; } +void ForLoopGenerator::ensureAccInParallelLoop( + GenerateLoopHelper &loopHelperParam, ArrayRef parallelAxis, + Value multiReductionAcc, DenseMap &nextAnchorArgsIdxMap, + SmallVector &nextAnchorArgs) { + if 
(loopHelperParam.anchorIdx == parallelAxis.size() - 1) { + // Ensure accumalate expression appear in this parallel anchor + // position. If it not appear in current anchor, we must move it in + // here. + // 1. delete it in operation queue + // 2. move it in current movedqueue + DenseSet argsSet(nextAnchorArgs.begin(), nextAnchorArgs.end()); + std::queue checkAccQueue(*loopHelperParam.movedOps); + Value accInitVal; + while (!checkAccQueue.empty()) { + Operation *cur = checkAccQueue.front(); + checkAccQueue.pop(); + bool ok = false; + for (auto x : cur->getResults()) { + if (x == multiReductionAcc) { + accInitVal = x; + ok = true; + break; + } + } + if (ok) + break; + } + if (accInitVal) { + // we put initVal at last for loop args + if (!argsSet.contains(accInitVal)) { + nextAnchorArgs.emplace_back(accInitVal); + nextAnchorArgsIdxMap[accInitVal] = nextAnchorArgs.size() - 1; + loopHelperParam.loopArgsOriginalOperandMap[accInitVal] = + multiReductionAcc; + loopHelperParam.originalOperandLoopArgsMap[multiReductionAcc] = + accInitVal; + } + loopHelperParam.loopIterArgs = nextAnchorArgs; + loopHelperParam.nextAnchorResultsIdxMap = nextAnchorArgsIdxMap; + } else { + llvm::llvm_unreachable_internal("Wrong accumualte source value. Because " + "acc value must appear in here."); + } + } +} + /// Generate for loop for parallel axis of `vector.multi_reduction`. /// This function also call reduction axis for loop scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( - OpBuilder &opBuilder, - DenseMap> &indiceLoopMap, - GenerateLoopHelper &loopHelperParam) { + OpBuilder &opBuilder, GenerateLoopHelper &loopHelperParam) { MultiReductionCanonicalizer &rdCanonicalizer = getMultiRdCanonicalizers()[loopHelperParam.groupIdx]; vector::MultiDimReductionOp &multiReductionOp = @@ -1384,10 +1439,8 @@ scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( Value zero = makeIndexArithConstantOp(opBuilder, loc, 0); size_t grpMaxStep = getFusionStrategy().getGroupMaxSteps()[loopHelperParam.groupIdx]; - size_t actualStep = (loopHelperParam.anchorIdx == parallelAxis.size() - 1 and - !rdCanonicalizer.getHasLastDimReduction()) - ? grpMaxStep - : 1; + size_t actualStep = + (loopHelperParam.anchorIdx == parallelAxis.size() - 1 ? grpMaxStep : 1); Value forSteps = makeIndexArithConstantOp(opBuilder, loc, actualStep); // last dim reduction need to a generate dim=16 loop for fused with pre-op @@ -1434,50 +1487,9 @@ scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( loopHelperParam); loopHelperParam.updateDataAfterPreOpMove(nextAnchorArgsIdxMap, nextAnchorArgs); - - if (loopHelperParam.anchorIdx == parallelAxis.size() - 1) { - // Ensure accumalate expression appear in this parallel anchor - // position. If it not appear in current anchor, we must move it in - // here. - // 1. delete it in operation queue - // 2. 
move it in current movedqueue - DenseSet argsSet(nextAnchorArgs.begin(), - nextAnchorArgs.end()); - std::queue checkAccQueue(movedQueue); - Value accInitVal; - while (!checkAccQueue.empty()) { - Operation *cur = checkAccQueue.front(); - checkAccQueue.pop(); - bool ok = false; - for (auto x : cur->getResults()) { - if (x == multiReductionAcc) { - accInitVal = x; - ok = true; - break; - } - } - if (ok) - break; - } - if (accInitVal) { - // we put initVal at last for loop args - if (!argsSet.contains(accInitVal)) { - nextAnchorArgs.emplace_back(accInitVal); - nextAnchorArgsIdxMap[accInitVal] = nextAnchorArgs.size() - 1; - loopHelperParam.loopArgsOriginalOperandMap[accInitVal] = - multiReductionAcc; - loopHelperParam.originalOperandLoopArgsMap[multiReductionAcc] = - accInitVal; - } - loopHelperParam.loopIterArgs = nextAnchorArgs; - loopHelperParam.nextAnchorResultsIdxMap = nextAnchorArgsIdxMap; - } else { - llvm::llvm_unreachable_internal( - "Wrong accumualte source value. Because " - "acc value must appear in here."); - } - } - + ensureAccInParallelLoop(loopHelperParam, parallelAxis, + multiReductionAcc, nextAnchorArgsIdxMap, + nextAnchorArgs); scf::ForOp nxtFor; // 2. generate next for loop bool useParallelLoop = @@ -1485,8 +1497,7 @@ scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( loopHelperParam.anchorIdx < parallelAxis.size() - 1; loopHelperParam.anchorIdx += 1; if (useParallelLoop) { - nxtFor = - parallelAxisGenerateForLoop(b, indiceLoopMap, loopHelperParam); + nxtFor = parallelAxisGenerateForLoop(b, loopHelperParam); } else { nxtFor = reductionAxisGenerateForLoop(b, 0, loopHelperParam); } @@ -1522,27 +1533,24 @@ scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( DenseMap localAnchorArgsIdxMap; DenseMap localOriginalOperandLoopArgsMap, localLoopArgsOriginalOperandMap; - SmallVector argsArray; - argsArray.emplace_back(accVal); - localAnchorArgsIdxMap[accVal] = 0; + updateLoopArgsData( + accVal, multiReductionAcc, argsArray, localAnchorArgsIdxMap, + localOriginalOperandLoopArgsMap, localLoopArgsOriginalOperandMap); + size_t accLoopStateIdx = loopHelperParam.currentLoopStateIdxMap [loopHelperParam .originalOperandLoopArgsMap[multiReductionAcc]]; - localLoopArgsOriginalOperandMap[accVal] = multiReductionAcc; - localOriginalOperandLoopArgsMap[multiReductionAcc] = accVal; - for (auto [idx, x] : llvm::enumerate(loopState)) { if (idx == accLoopStateIdx) continue; - - argsArray.emplace_back(x); - localAnchorArgsIdxMap[x] = argsArray.size() - 1; - Value originalValue = loopHelperParam.loopArgsOriginalOperandMap - [loopHelperParam.loopIterArgs[idx]]; - localOriginalOperandLoopArgsMap[originalValue] = x; - localLoopArgsOriginalOperandMap[x] = originalValue; + updateLoopArgsData(x, + loopHelperParam.loopArgsOriginalOperandMap + [loopHelperParam.loopIterArgs[idx]], + argsArray, localAnchorArgsIdxMap, + localOriginalOperandLoopArgsMap, + localLoopArgsOriginalOperandMap); } loopHelperParam.updateCurrentArgsStatus( localAnchorArgsIdxMap, argsArray, localOriginalOperandLoopArgsMap, @@ -1594,32 +1602,32 @@ scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( } scf::ForOp ForLoopGenerator::generateTransposeForLoopWithLastDim( - OpBuilder &opBuilder, const size_t grpIdx, const size_t forDimIdx, - const int tpSteps, const Location &loc, SmallVector &inductionVars, - ValueRange iterArgs, DenseMap &operandIdxMap, - DenseMap &originalOperandMap, Operation *successorWriteOp) { - auto &tpCanonicalizer = getTransposeCanonicalizers()[grpIdx]; + OpBuilder &opBuilder, const int tpSteps, const 
Location &loc, + Operation *successorWriteOp, GenerateLoopHelper &loopHelperParam) { + auto &tpCanonicalizer = + getTransposeCanonicalizers()[loopHelperParam.groupIdx]; vector::TransposeOp &tpOp = tpCanonicalizer.getCandidateOps()[0]; VectorType vtType = tpOp.getVector().getType(); size_t rank = vtType.getRank(); auto zero = makeIndexArithConstantOp(opBuilder, loc, 0); - bool isTransposeDim = forDimIdx == tpCanonicalizer.getFirstTpIdx() or - forDimIdx == tpCanonicalizer.getSecondTpIdx(); + bool isTransposeDim = + loopHelperParam.anchorIdx == tpCanonicalizer.getFirstTpIdx() or + loopHelperParam.anchorIdx == tpCanonicalizer.getSecondTpIdx(); auto forSteps = makeIndexArithConstantOp(opBuilder, loc, isTransposeDim ? tpSteps : 1); - auto numIter = - makeIndexArithConstantOp(opBuilder, loc, vtType.getShape()[forDimIdx]); + auto numIter = makeIndexArithConstantOp( + opBuilder, loc, vtType.getShape()[loopHelperParam.anchorIdx]); VectorType kernelType = VectorType::get({tpSteps, tpSteps}, vtType.getElementType()); // generate transpose for loop return opBuilder.create( - loc, zero, numIter, forSteps, iterArgs, + loc, zero, numIter, forSteps, loopHelperParam.loopIterArgs, [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) { - inductionVars.emplace_back(iv); + loopHelperParam.inductionVars.emplace_back(iv); // inner most body of the loop - if (forDimIdx == rank - 1) { + if (loopHelperParam.anchorIdx == rank - 1) { // transfer read from source tensor Value source = tpOp->getOperand(0); auto readSourceOp = @@ -1636,27 +1644,29 @@ scf::ForOp ForLoopGenerator::generateTransposeForLoopWithLastDim( loc, /*vectorType=*/kernelType, /*source=*/readSourceOp.getSource(), - /*indices=*/inductionVars, + /*indices=*/loopHelperParam.inductionVars, /*padding=*/padValue, /*inBounds=*/inBoundsVal); SmallVector perm{1, 0}; auto transposeOp = b.create( loc, transferReadOp->getResults()[0], perm); - SmallVector writeVars(inductionVars.begin(), - inductionVars.end()); + SmallVector writeVars(loopHelperParam.inductionVars.begin(), + loopHelperParam.inductionVars.end()); writeVars[tpCanonicalizer.getSecondTpIdx()] = - inductionVars[tpCanonicalizer.getFirstTpIdx()]; + loopHelperParam.inductionVars[tpCanonicalizer.getFirstTpIdx()]; writeVars[tpCanonicalizer.getFirstTpIdx()] = - inductionVars[tpCanonicalizer.getSecondTpIdx()]; + loopHelperParam.inductionVars[tpCanonicalizer.getSecondTpIdx()]; auto writeOp = b.create( loc, transposeOp->getResults()[0], loopState[0], writeVars, inBoundsVal); maybeYieldValue(b, loc, writeOp->getResults()); } else { // outter loop + loopHelperParam.anchorIdx += 1; + loopHelperParam.loopIterArgs = loopState; auto nxtFor = generateTransposeForLoopWithLastDim( - b, grpIdx, forDimIdx + 1, tpSteps, loc, inductionVars, loopState, - operandIdxMap, originalOperandMap, successorWriteOp); + b, tpSteps, loc, successorWriteOp, loopHelperParam); + loopHelperParam.anchorIdx -= 1; maybeYieldValue(b, loc, nxtFor->getResults()); } }); @@ -1702,15 +1712,16 @@ void ForLoopGenerator::rearrageMultiReductionIR( for (size_t i = 0; i < parallelAxis.size(); i++) { varLoopIdxMap[parallelAxis[i]] = i; } - for (size_t i = parallelAxis.size(); i < groupVector.getRank(); i++) { - varLoopIdxMap[reductionAxis[i - parallelAxis.size()]] = i; + size_t offset = rdCanonicalizer.hasLastDimReduction() ? 
1 : 0; + for (size_t i = parallelAxis.size() + offset; + i < groupVector.getRank() + offset; i++) { + varLoopIdxMap[reductionAxis[i - parallelAxis.size() - offset]] = i; } while (!tmpSourceQ.empty()) { auto *curOp = tmpSourceQ.front(); tmpSourceQ.pop(); - if (isa(curOp)) { + if (isa(curOp)) getCurrentGroupIndiceLoopMap(indiceLoopMap, grpIdx, curOp, varLoopIdxMap); - } } // move accumulate related operation to operation first @@ -1766,9 +1777,9 @@ ForLoopGenerator::generateMultiReductionForLoop(const size_t grpIdx) { getMultiRdCanonicalizers()[grpIdx]; OpBuilder opBuilder(rdCanonicalizer.getCandidateOps()[0]); + loopHelper.indiceLoopMap = indiceLoopMap; - scf::ForOp forOp = - parallelAxisGenerateForLoop(opBuilder, indiceLoopMap, loopHelper); + scf::ForOp forOp = parallelAxisGenerateForLoop(opBuilder, loopHelper); replaceOpUsersWithForLoopResult(forOp, grpIdx, loopHelper.nextAnchorResults, loopHelper.nextAnchorResultsIdxMap, loopHelper.nextAnchorResultOrignalResultMap); @@ -1783,27 +1794,33 @@ ForLoopGenerator::generateMultiReductionForLoop(const size_t grpIdx) { // generate simple data movement for loop scf::ForOp ForLoopGenerator::generateTransposeScalarDataMovement( - OpBuilder &opBuilder, const size_t grpIdx, const size_t forDimIdx, - const Location &loc, SmallVector &inductionVars, - const ValueRange &iterArgs, DenseMap &tpAxisMap) { - auto &tpCanonicalizer = getTransposeCanonicalizers()[grpIdx]; + OpBuilder &opBuilder, const Location &loc, + DenseMap &tpAxisMap, GenerateLoopHelper &loopHelperParam) { + auto &tpCanonicalizer = + getTransposeCanonicalizers()[loopHelperParam.groupIdx]; vector::TransposeOp &tpOp = tpCanonicalizer.getCandidateOps()[0]; VectorType vtType = tpOp.getSourceVectorType(); size_t rank = vtType.getRank(); auto zero = makeIndexArithConstantOp(opBuilder, loc, 0); - auto forSteps = makeIndexArithConstantOp(opBuilder, loc, 1); - auto numIter = - makeIndexArithConstantOp(opBuilder, loc, vtType.getShape()[forDimIdx]); - VectorType kernelType = VectorType::get({1}, vtType.getElementType()); + size_t vecStep = tpCanonicalizer.transposeOnLastDim() + ? tpCanonicalizer.getVectorStep() + : 1; + auto forSteps = makeIndexArithConstantOp( + opBuilder, loc, loopHelperParam.anchorIdx == rank - 1 ? 
(vecStep) : 1); + auto numIter = makeIndexArithConstantOp( + opBuilder, loc, vtType.getShape()[loopHelperParam.anchorIdx]); + + SmallVector vecShapes(1, vecStep); + VectorType kernelType = VectorType::get(vecShapes, vtType.getElementType()); // generate transpose for loop return opBuilder.create( - loc, zero, numIter, forSteps, iterArgs, + loc, zero, numIter, forSteps, loopHelperParam.loopIterArgs, [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) { - inductionVars.emplace_back(iv); + loopHelperParam.inductionVars.emplace_back(iv); // inner most body of the loop - if (forDimIdx == rank - 1) { + if (loopHelperParam.anchorIdx == rank - 1) { // transfer read from source tensor Value source = tpOp->getOperand(0); auto readSourceOp = @@ -1812,6 +1829,7 @@ scf::ForOp ForLoopGenerator::generateTransposeScalarDataMovement( for (Operation *x : tpOp->getUsers()) { if (isa(x)) { successorWriteOp = cast(x); + break; } } auto padValue = b.create( @@ -1820,14 +1838,15 @@ scf::ForOp ForLoopGenerator::generateTransposeScalarDataMovement( SmallVector writeVars; size_t itrIdx = 0; while (itrIdx < rank) { - writeVars.emplace_back(inductionVars[tpAxisMap[itrIdx]]); + writeVars.emplace_back( + loopHelperParam.inductionVars[tpAxisMap[itrIdx]]); itrIdx++; } auto transferReadOp = b.create( loc, /*vectorType=*/kernelType, /*source=*/readSourceOp.getSource(), - /*indices=*/inductionVars, + /*indices=*/loopHelperParam.inductionVars, /*padding=*/padValue, /*inBounds=*/inBoundsVal); @@ -1839,9 +1858,11 @@ scf::ForOp ForLoopGenerator::generateTransposeScalarDataMovement( maybeYieldValue(b, loc, writeOp->getResults()); } else { // outter loop - auto nxtFor = generateTransposeScalarDataMovement( - b, grpIdx, forDimIdx + 1, loc, inductionVars, loopState, - tpAxisMap); + loopHelperParam.anchorIdx += 1; + loopHelperParam.loopIterArgs = loopState; + auto nxtFor = generateTransposeScalarDataMovement(b, loc, tpAxisMap, + loopHelperParam); + loopHelperParam.anchorIdx -= 1; maybeYieldValue(b, loc, nxtFor->getResults()); } }); @@ -2129,57 +2150,9 @@ scf::ForOp ForLoopGenerator::generateTransposeForLoop(const size_t grpIdx) { DenseMap originalOperandMap, operandOriginalMap, resultIdxMap, forResultOrignalResultMap; SmallVector iterArgs; - GenerateLoopHelper loopHelper(grpIdx); + GenerateLoopHelper loopHelper(grpIdx, 0); prepareForLoopArgs(grpIdx, loopHelper); - operandIdxMap = loopHelper.currentLoopStateIdxMap; - originalOperandMap = loopHelper.originalOperandLoopArgsMap; - operandOriginalMap = loopHelper.loopArgsOriginalOperandMap; - iterArgs = loopHelper.loopIterArgs; - SmallVector inductionVars; - // TODO: need to process transpose on all one dim - // don't need to do the transpose - // if (tpCanonicalizer.isTransposeOnAllOneDim()) { - // removeOpInCurrentGroups(grpIdx, tpOp, - // tpOp->getOperand(0).getDefiningOp()); - - // // generate nested for loop - // SmallVector nextLoopResults; - // DenseMap resultIdxMap; - // SmallVector inductionVars; - // DenseMap forResultOrignalResultMap; - // Operation *firstOp = getFusionStrategy().getOpGroups()[grpIdx].front(); - // OpBuilder b(firstOp); - // VectorType groupVector = - // getFusionStrategy().getGroupBiggestRankVectorType()[grpIdx]; - // ArrayRef shapes = groupVector.getShape(); - - // DenseMap> indiceLoopMap; - - // scf::ForOp forOp = constructNestedForOp( - // 0, grpIdx, b, firstOp->getLoc(), iterArgs, shapes, inductionVars, - // operandIdxMap, originalOperandMap, operandOriginalMap, - // nextLoopResults, resultIdxMap, forResultOrignalResultMap, - // 
indiceLoopMap); - - // forOp->dump(); - // DenseSet forOpChildOps; - // forOp->walk([&](Operation *op) { forOpChildOps.insert(op); }); - // auto replaceIfFn = [&](OpOperand &use) { - // return not forOpChildOps.contains(use.getOwner()); - // }; - // for (auto x : nextLoopResults) { - // auto originalResult = forResultOrignalResultMap[x]; - // rewriter.replaceOpUsesWithIf(originalResult.getDefiningOp(), - // forOp->getResults()[resultIdxMap[x]], - // replaceIfFn); - // rectifyGroupOperands(grpIdx, originalResult, - // forOp->getResults()[resultIdxMap[x]]); - // } - // // clear current group operation - // clearCurrentOperationGroup(grpIdx); - // return forOp; - // } OpBuilder b(tpOp); int tpStep = TransposeCanonicalizer::TRANSPOSE_KERNEL::KERNEL_16X16; // only contains last dim can use fast transpose algorithm @@ -2187,8 +2160,7 @@ scf::ForOp ForLoopGenerator::generateTransposeForLoop(const size_t grpIdx) { tpCanonicalizer.getSecondTpIdx() == (rank - 1)) and isTwoDTranspose) { scf::ForOp forOp = generateTransposeForLoopWithLastDim( - b, grpIdx, 0, tpStep, tpOp.getLoc(), inductionVars, iterArgs, - operandIdxMap, originalOperandMap, successorWriteOp); + b, tpStep, tpOp.getLoc(), successorWriteOp, loopHelper); rewriter.replaceOp(successorWriteOp, forOp); // clear current group operation @@ -2202,8 +2174,8 @@ scf::ForOp ForLoopGenerator::generateTransposeForLoop(const size_t grpIdx) { itrIdx++; } // scalar data movement - scf::ForOp forOp = generateTransposeScalarDataMovement( - b, grpIdx, 0, tpOp.getLoc(), inductionVars, iterArgs, tpAxisMap); + scf::ForOp forOp = generateTransposeScalarDataMovement(b, tpOp.getLoc(), + tpAxisMap, loopHelper); rewriter.replaceOp(successorWriteOp, forOp); clearCurrentOperationGroup(grpIdx); @@ -2336,6 +2308,20 @@ bool TransposeCanonicalizer::isTwoDTranspose() { return diffCount == 2; } +bool TransposeCanonicalizer::transposeOnLastDim() { + ArrayRef permutation = getCandidateOps()[0].getPermutation(); + size_t rank = permutation.size(); + if (permutation[rank - 1] != rank - 1) + return false; + + VectorType vtType = getCandidateOps()[0].getResultVectorType(); + + if (vtType.getShape()[rank - 1] % getVectorStep() != 0) + return false; + + return true; +} + bool ShapeCastCanonicalizer::isReadWriteOnLastDim() { vector::ShapeCastOp &shapeCastOp = getCandidateOps()[0]; VectorType sourceType = shapeCastOp.getSourceVectorType(); @@ -2379,8 +2365,9 @@ bool ShapeCastCanonicalizer::isReadWriteOnLastDim() { return set.contains(largeRankType.getRank() - 1); } -template void addDummyInit(SmallVector &canonicalizer) { - canonicalizer.emplace_back(T({})); +template +void addDummyInit(SmallVector &canonicalizer, size_t steps = 1) { + canonicalizer.emplace_back(T({}, steps)); }; void CanonicalizerVectorOperation::clearSpecialOperationCanonicalizers() { @@ -2390,19 +2377,19 @@ void CanonicalizerVectorOperation::clearSpecialOperationCanonicalizers() { getShapeCastCanonicalizers().clear(); } -void CanonicalizerVectorOperation::dummyInitSpecialOperation() { - addDummyInit(getMultiRdCanonicalizers()); - addDummyInit(getBroadcastCanonicalizers()); - addDummyInit(getTransposeCanonicalizers()); - addDummyInit(getShapeCastCanonicalizers()); +void CanonicalizerVectorOperation::dummyInitSpecialOperation(size_t steps) { + addDummyInit(getMultiRdCanonicalizers(), steps); + addDummyInit(getBroadcastCanonicalizers(), steps); + addDummyInit(getTransposeCanonicalizers(), steps); + addDummyInit(getShapeCastCanonicalizers(), steps); } void 
CanonicalizerVectorOperation::initSpeicalOperationCanonicalizers() { clearSpecialOperationCanonicalizers(); SmallVector, 8> &opGroups = getFusionStrategy().getOpGroups(); - for (auto &grp : opGroups) { - dummyInitSpecialOperation(); + for (auto [idx, grp] : llvm::enumerate(opGroups)) { + dummyInitSpecialOperation(getFusionStrategy().getGroupMaxSteps()[idx]); if (grp.empty()) continue; @@ -2530,7 +2517,6 @@ void CanonicalizerVectorOperation::run() { for (size_t idx = 0; idx < fusionStrategy.getOpGroups().size(); ++idx) { generateGroupOpVectorizedIR(idx); } - func->dump(); // 3. Some IR cleanup work DominanceInfo domInfo; eliminateCommonSubExpressions(rewriter, domInfo, func); @@ -3456,7 +3442,7 @@ void VectorOperationAnalyzer::analysisGroupMaxSteps() { auto calculateOpSteps = [&](Type type) { auto opType = dyn_cast(type); if (opType) - steps = std::min(steps, (uint32_t)getDataTypeMAXSIMDLength(opType)); + steps = std::min(steps, (uint32_t)getDataTypeValidSteps(opType)); }; while (!tmpQueue.empty()) { auto op = tmpQueue.front(); diff --git a/lib/gc/Transforms/TilingVector.h b/lib/gc/Transforms/TilingVector.h index e2791010e..3cb8d26fb 100644 --- a/lib/gc/Transforms/TilingVector.h +++ b/lib/gc/Transforms/TilingVector.h @@ -327,9 +327,10 @@ class VectorFusionStrategy : public TypeHelper { /// and we directly convert the operations into physical register sizes. enum CanonicalizerKind { OperationsGroup, Operations }; -template class SpecialOperationCanonicalizer { +template class SpecialOperationCanonicalizer : virtual TypeHelper { private: - llvm::SmallVector candidateRdOps; + SmallVector candidateRdOps; + size_t vectorStep = 1; public: enum class SpecialOperationKind { @@ -344,13 +345,18 @@ template class SpecialOperationCanonicalizer { public: SpecialOperationCanonicalizer() = default; - SpecialOperationCanonicalizer(const llvm::SmallVector &candidateRdOps, + SpecialOperationCanonicalizer(const SmallVector &candidateRdOps, SpecialOperationKind kind) : candidateRdOps(candidateRdOps), kind(kind) {} + SpecialOperationCanonicalizer(const SmallVector &candidateRdOps, + SpecialOperationKind kind, size_t step) + : candidateRdOps(candidateRdOps), vectorStep(step), kind(kind) {} llvm::SmallVector &getCandidateOps(); virtual ~SpecialOperationCanonicalizer() {} virtual void prepareSpecialOperationInfo() = 0; - SpecialOperationKind getKind() { return kind; } + SpecialOperationKind getKind() noexcept { return kind; } + void setVectorStep(size_t step) noexcept { vectorStep = step; } + size_t getVectorStep() noexcept { return vectorStep; } }; enum class MultiReduceOpAxisKind { Reduction, Parallel }; @@ -370,9 +376,10 @@ class MultiReductionCanonicalizer public: MultiReductionCanonicalizer( - const llvm::SmallVector &candidateRdOps) + const SmallVector &candidateRdOps, + size_t steps = 1) : SpecialOperationCanonicalizer( - candidateRdOps, SpecialOperationKind::OP_MultiDimReduction) { + candidateRdOps, SpecialOperationKind::OP_MultiDimReduction, steps) { isStandaloneOp = candidateRdOps.size() == 1; prepareSpecialOperationInfo(); }; @@ -420,9 +427,10 @@ class BroadcastCanonicalizer private: public: BroadcastCanonicalizer( - const llvm::SmallVector &candidateBcOps) + const llvm::SmallVector &candidateBcOps, + size_t steps = 1) : SpecialOperationCanonicalizer( - candidateBcOps, SpecialOperationKind::OP_Broadcast) {}; + candidateBcOps, SpecialOperationKind::OP_Broadcast, steps) {}; virtual ~BroadcastCanonicalizer() noexcept {} void prepareSpecialOperationInfo() override {} static bool 
classof(SpecialOperationCanonicalizer *canonicalizer) { @@ -437,9 +445,10 @@ class TransposeCanonicalizer public: TransposeCanonicalizer( - const llvm::SmallVector &candidateTpOps) + const llvm::SmallVector &candidateTpOps, + size_t steps = 1) : SpecialOperationCanonicalizer( - candidateTpOps, SpecialOperationKind::OP_Transpose) {}; + candidateTpOps, SpecialOperationKind::OP_Transpose, steps) {}; virtual ~TransposeCanonicalizer() noexcept {} void prepareSpecialOperationInfo() override; static bool classof(SpecialOperationCanonicalizer *canonicalizer) { @@ -453,6 +462,7 @@ class TransposeCanonicalizer size_t getSecondTpIdx() noexcept { return secondTpIdx; } bool isTwoDTranspose(); bool isTransposeOnAllOneDim(); + bool transposeOnLastDim(); }; class ShapeCastCanonicalizer @@ -460,9 +470,10 @@ class ShapeCastCanonicalizer private: public: ShapeCastCanonicalizer( - const SmallVector &candidateScOps) + const SmallVector &candidateScOps, + size_t steps = 1) : SpecialOperationCanonicalizer( - candidateScOps, SpecialOperationKind::OP_ShapeCast) {}; + candidateScOps, SpecialOperationKind::OP_ShapeCast, steps) {}; virtual ~ShapeCastCanonicalizer() {} void prepareSpecialOperationInfo() override {} static bool classof(SpecialOperationCanonicalizer *canonicalizer) { @@ -487,15 +498,15 @@ class CanonicalizerCommonUsedData : public TypeHelper { /// analysis the operation's operands and results SmallVector>, 8> groupOpResults; - llvm::SmallVector, 8> groupOpInitArgs; + SmallVector, 8> groupOpInitArgs; // store read and write operations permutation maps in order to convenient // to replace loop induction var - llvm::DenseMap opPermuationMap; - llvm::SmallVector multiRdCanonicalizers; - llvm::SmallVector broadcastCanonicalizers; - llvm::SmallVector transposeCanonicalizers; - llvm::SmallVector shapeCastCanonicalizers; + DenseMap opPermuationMap; + SmallVector multiRdCanonicalizers; + SmallVector broadcastCanonicalizers; + SmallVector transposeCanonicalizers; + SmallVector shapeCastCanonicalizers; public: CanonicalizerCommonUsedData() = default; @@ -506,7 +517,7 @@ class CanonicalizerCommonUsedData : public TypeHelper { VectorFusionStrategy &fusionStrategy, SmallVector>, 8> &groupOpResults, - SmallVector, 8> &groupOpInitArgs, + SmallVector, 8> &groupOpInitArgs, DenseMap &opPermuationMap) : fusionStrategy(fusionStrategy), groupOpResults(groupOpResults), groupOpInitArgs(groupOpInitArgs), opPermuationMap(opPermuationMap) {} @@ -537,8 +548,8 @@ class CanonicalizerCommonUsedData : public TypeHelper { groupOpResults = std::move(results); } - void setGroupOpIterArgs( - const llvm::SmallVector, 8> &initArgs) { + void + setGroupOpIterArgs(const SmallVector, 8> &initArgs) { groupOpInitArgs = std::move(initArgs); } @@ -712,29 +723,31 @@ class ForLoopGenerator : virtual public CanonicalizerCommonUsedData { const size_t reductionIdx, GenerateLoopHelper &loopHelperParam); - scf::ForOp parallelAxisGenerateForLoop( - OpBuilder &opBuilder, - DenseMap> &indiceLoopMap, - GenerateLoopHelper &loopHelperParam); + scf::ForOp parallelAxisGenerateForLoop(OpBuilder &opBuilder, + GenerateLoopHelper &loopHelperParam); + + void ensureAccInParallelLoop(GenerateLoopHelper &loopHelperParam, + ArrayRef parallelAxis, + Value multiReductionAcc, + DenseMap &nextAnchorArgsIdxMap, + SmallVector &nextAnchorArgs); vector::TransferReadOp cloneReductionTransferRead( Value &source, OpBuilder &b, IRMapping &readMap, const llvm::SmallVector ¶llelAxis, - llvm::SmallVector &inductionVars, bool lastDimReduction, + SmallVector &inductionVars, bool 
lastDimReduction, MultiReduceOpAxisKind rdKind = MultiReduceOpAxisKind::Parallel); /// generate for loop for transpose operation scf::ForOp generateTransposeForLoop(const size_t groupId); scf::ForOp generateTransposeForLoopWithLastDim( - OpBuilder &opBuilder, const size_t grpIdx, const size_t forDimIdx, - const int tpSteps, const Location &loc, SmallVector &inductionVars, - ValueRange iterArgs, DenseMap &operandIdxMap, - DenseMap &originalOperandMap, Operation *successorWriteOp); + OpBuilder &opBuilder, const int tpSteps, const Location &loc, + Operation *successorWriteOp, GenerateLoopHelper &loopHelperParam); - scf::ForOp generateTransposeScalarDataMovement( - OpBuilder &opBuilder, const size_t grpIdx, const size_t forDimIdx, - const Location &loc, SmallVector &inductionVars, - const ValueRange &iterArgs, DenseMap &tpAxisMap); + scf::ForOp + generateTransposeScalarDataMovement(OpBuilder &opBuilder, const Location &loc, + DenseMap &tpAxisMap, + GenerateLoopHelper &loopHelperParam); // shapecast scf::ForOp generateShapeCastForLoop(const size_t grpIdx); @@ -835,7 +848,7 @@ class CanonicalizerVectorOperation : virtual public ForLoopGenerator, // void canonicalizeSpecialOperation(); void clearSpecialOperationCanonicalizers(); - void dummyInitSpecialOperation(); + void dummyInitSpecialOperation(size_t steps); void initSpeicalOperationCanonicalizers(); void run(); diff --git a/test/mlir/test/gc/Transforms/cpu-vetor-distribution.mlir b/test/mlir/test/gc/Transforms/cpu-vetor-distribution.mlir index 51e6ef721..081e13af8 100644 --- a/test/mlir/test/gc/Transforms/cpu-vetor-distribution.mlir +++ b/test/mlir/test/gc/Transforms/cpu-vetor-distribution.mlir @@ -232,8 +232,8 @@ func.func @fuse_mlp_test4(%arg0: tensor<128x512xbf16>, %arg1: tensor<32x8x16x32x // CHECK: scf.forall // CHECK-COUNT-6: scf.for // CHECK-COUNT-4: scf.for -// CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<32x1x16x32xbf16>, vector<1xbf16> -// CHECK: %[[WRITE0:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<1xbf16>, tensor<32x16x1x32xbf16> +// CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<32x1x16x32xbf16>, vector<32xbf16> +// CHECK: %[[WRITE0:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<32xbf16>, tensor<32x16x1x32xbf16> // CHECK: %[[FILL0:.*]] = linalg.fill // CHECK-COUNT-3: scf.for // CHECK: %[[APPLY0:.*]] = affine.apply @@ -362,9 +362,9 @@ func.func @elem_pack_transpose_inner_dims_test5(%arg0: tensor<128x256xi32>, %des // CHECK: scf.for %[[arg2:.*]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[arg3:.*]] = %[[EMPTY2]]) -> (tensor<16x4x32x16xi32>) // CHECK: scf.for %[[arg4:.*]] = %[[C0]] to %[[C32]] step %[[C1]] iter_args(%[[arg5:.*]] = %[[arg3]]) -> (tensor<16x4x32x16xi32>) // CHECK: scf.for %[[arg6:.*]] = %[[C0]] to %[[C16]] step %[[C1]] iter_args(%[[arg7:.*]] = %[[arg5]]) -> (tensor<16x4x32x16xi32>) -// CHECK: scf.for %[[arg8:.*]] = %[[C0]] to %[[C16]] step %[[C1]] iter_args(%[[arg9:.*]] = %[[arg7]]) -> (tensor<16x4x32x16xi32>) -// CHECK: %[[READ2:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<4x32x16x16xi32>, vector<1xi32> -// CHECK: %[[WRITE2:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<1xi32>, tensor<16x4x32x16xi32> +// CHECK: scf.for %[[arg8:.*]] = %[[C0]] to %[[C16]] step %[[C16]] iter_args(%[[arg9:.*]] = %[[arg7]]) -> (tensor<16x4x32x16xi32>) +// CHECK: %[[READ2:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : 
tensor<4x32x16x16xi32>, vector<16xi32> +// CHECK: %[[WRITE2:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<16xi32>, tensor<16x4x32x16xi32> #map6 = affine_map<(d0, d1) -> (d0, d1)> func.func @elem_pack_transpose_outer_dims_test6(%arg0: tensor<128x256xi32>, %dest: tensor<16x4x32x16xi32>) -> tensor<16x4x32x16xi32>{ %init = tensor.empty() : tensor<128x256xi32> @@ -459,9 +459,9 @@ func.func @elem_pack_transpose_inner_and_outer_dims_test7(%arg0: tensor<128x256x // CHECK: scf.for %[[arg4:.*]] = %[[C0]] to %[[C56]] step %[[C1]] iter_args(%[[arg5:.*]] = %[[arg3]]) -> (tensor<1x2x56x57x32xf32>) // CHECK: scf.for %[[arg6:.*]] = %[[C0]] to %[[C57]] step %[[C1]] iter_args(%[[arg7:.*]] = %[[arg5]]) -> (tensor<1x2x56x57x32xf32>) // CHECK: scf.for %[[arg8:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[arg9:.*]] = %[[arg7]]) -> (tensor<1x2x56x57x32xf32>) -// CHECK: scf.for %[[arg10:.*]] = %[[C0]] to %[[C32]] step %[[C1]] iter_args(%[[arg11:.*]] = %[[arg9]]) -> (tensor<1x2x56x57x32xf32>) -// CHECK: %[[READ2:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<1x56x57x2x32xf32>, vector<1xf32> -// CHECK: %[[WRITE2:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<1xf32>, tensor<1x2x56x57x32xf32> +// CHECK: scf.for %[[arg10:.*]] = %[[C0]] to %[[C32]] step %[[C16]] iter_args(%[[arg11:.*]] = %[[arg9]]) -> (tensor<1x2x56x57x32xf32>) +// CHECK: %[[READ2:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<1x56x57x2x32xf32>, vector<16xf32> +// CHECK: %[[WRITE2:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<16xf32>, tensor<1x2x56x57x32xf32> #map8 = affine_map<(d0, d1, d2, d3) -> (d3)> #map9 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> func.func @elem_pack_transpose_inner_and_outer_dims2_test8(%arg0: tensor<64xf32>, %dest: tensor<1x2x56x57x32xf32>) -> tensor<1x2x56x57x32xf32> { @@ -570,11 +570,11 @@ func.func @reduce_fusePostOp_test11(%input: tensor<16x32x64xf32>, // CHECK: %[[C1:.*]] = arith.constant 1 : index // CHECK: %[[C16:.*]] = arith.constant 16 : index // CHECK: scf.for %[[arg2:.*]] = %[[C0]] to %[[C16]] step %[[C1]] iter_args(%[[arg3:.*]] = %arg1) -> (tensor<16x32xf32>) -// CHECK: scf.for %[[arg4:.*]] = %[[C0]] to %[[C32]] step %[[C1]] iter_args(%[[arg5:.*]] = %[[arg3]]) -> (tensor<16x32xf32>) +// CHECK: scf.for %[[arg4:.*]] = %[[C0]] to %[[C32]] step %[[C16]] iter_args(%[[arg5:.*]] = %[[arg3]]) -> (tensor<16x32xf32>) // CHECK: %[[READ0:.*]] = vector.transfer_read %[[arg5]][%[[arg2]], %[[arg4]]], %[[CST_0]] {in_bounds = [true]} : tensor<16x32xf32>, vector<16xf32> // CHECK: scf.for %[[arg6:.*]] = %[[C0]] to %[[C16]] step %[[C1]] iter_args(%[[arg7:.*]] = %[[READ0]]) -> (vector<16xf32>) // CHECK: scf.for %[[arg8:.*]] = %[[C0]] to %[[C64]] step %[[C16]] iter_args(%[[arg9:.*]] = %[[CST]]) -> (vector<16xf32>) -// CHECK: %[[READ1:.*]] = vector.transfer_read %arg0[%[[arg2]], %[[arg4]], %[[arg6]]], {{.*}} {in_bounds = [true]} : tensor<16x32x64xf32>, vector<16xf32> +// CHECK: %[[READ1:.*]] = vector.transfer_read %arg0[%[[arg2]], %[[arg4]], %[[arg8]]], {{.*}} {in_bounds = [true]} : tensor<16x32x64xf32>, vector<16xf32> // CHECK: %[[ADD0:.*]] = arith.addf %[[READ1]], %[[READ1]] : vector<16xf32> // CHECK: %[[ADD1:.*]] = arith.addf %[[ADD0]], %[[arg9]] : vector<16xf32> // CHECK: %[[REDUCTION:.*]] = vector.reduction , {{.*}} : vector<16xf32> into f32 @@ -597,23 +597,23 @@ func.func @reduce_fuse_test12(%input: tensor<16x32x64xf32>, func.return %1 : tensor<16x32xf32> } -// func.func 
@pad_single_test13(%arg0: tensor<1x64x56x56xf32>) -> tensor<1x64x58x58xf32> { -// %cst = arith.constant 0.000000e+00 : f32 -// %padded = tensor.pad %arg0 low[0, 0, 1, 1] high[0, 0, 1, 1] { -// ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index): -// tensor.yield %cst : f32 -// } : tensor<1x64x56x56xf32> to tensor<1x64x58x58xf32> -// return %padded : tensor<1x64x58x58xf32> -// } - - -// func.func @pad_valid_pack_propagation(%arg0: tensor<1x64x56x56xf32>) -> tensor<1x2x58x58x32xf32> { -// %cst = arith.constant 0.000000e+00 : f32 -// %padded = tensor.pad %arg0 low[0, 0, 1, 1] high[0, 0, 1, 1] { -// ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index): -// tensor.yield %cst : f32 -// } : tensor<1x64x56x56xf32> to tensor<1x64x58x58xf32> -// %0 = tensor.empty() : tensor<1x2x58x58x32xf32> -// %1 = tensor.pack %padded inner_dims_pos = [1] inner_tiles = [32] into %0 : tensor<1x64x58x58xf32> -> tensor<1x2x58x58x32xf32> -// return %1 : tensor<1x2x58x58x32xf32> -// } +// CHECK-LABEL: func @add_small_tensor_test13 +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32> +// CHECK: %[[CST_0:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[TENSOR0:.*]] = tensor.empty() : tensor<2xf32> +// CHECK: scf.for +// CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}}: tensor<2xf32>, vector<1xf32> +// CHECK: %[[READ1:.*]] = vector.transfer_read {{.*}}, {{.*}}: tensor<2xf32>, vector<1xf32> +// CHECK: %[[ADD0:.*]] = arith.addf %[[READ1]], %[[READ0]] : vector<1xf32> +// CHECK: %[[ADD1:.*]] = arith.maximumf %[[ADD0]], %[[CST]] : vector<1xf32> +// CHECK: %[[WRITE:.*]] = vector.transfer_write {{.*}}, {{.*}} : vector<1xf32>, tensor<2xf32> +func.func @add_small_tensor_test13(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> { + %0 = tensor.empty() : tensor<2xf32> + %cst = arith.constant dense<0.000000e+00> : tensor<2xf32> + %1 = linalg.add ins(%arg0, %arg1 : tensor<2xf32>, tensor<2xf32>) outs(%0: tensor<2xf32>) -> tensor<2xf32> + %2 = linalg.max ins(%1, %cst : tensor<2xf32>, tensor<2xf32>) outs(%0: tensor<2xf32>) -> tensor<2xf32> + return %2 : tensor<2xf32> +} From 10e73f489507abeb6ebe5eba66b3056cb3ee01d4 Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Mon, 9 Sep 2024 15:25:02 +0800 Subject: [PATCH 42/66] test ci --- lib/gc/Transforms/Pipeline.cpp | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/lib/gc/Transforms/Pipeline.cpp b/lib/gc/Transforms/Pipeline.cpp index 5c69aec7c..b4a46d231 100644 --- a/lib/gc/Transforms/Pipeline.cpp +++ b/lib/gc/Transforms/Pipeline.cpp @@ -10,7 +10,6 @@ #include "mlir/Dialect/Arith/Transforms/Passes.h" #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" #include "mlir/Dialect/Bufferization/Transforms/Passes.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/Transforms/Passes.h" #include "mlir/Dialect/Linalg/Passes.h" @@ -78,8 +77,6 @@ void populateTensorPasses(mlir::OpPassManager &pm) { // scf + arith + math + vector + tensor + linalg.brgemm void populateVectorPasses(mlir::OpPassManager &pm) { - - pm.addNestedPass(createLowerToTileVector()); // Do promotion for math / arith ops pm.addNestedPass(math::createMathLegalizeToF32()); // sourceTypeStrs can be extended @@ -92,8 +89,7 @@ void populateVectorPasses(mlir::OpPassManager &pm) { // Bf16 cast elimilation pass 
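  // (e.g. folding the redundant bf16 <-> f32 cast pairs left behind by the
  //  f32 promotion passes above; the exact folds depend on the upstream
  //  canonicalization patterns that are registered)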
pm.addNestedPass(mlir::createCanonicalizerPass()); // oneDNN graph spec - // pm.addNestedPass(arith::createArithExpandOpsPass()); - pm.addNestedPass(createCPUPhysicalRegisterPass()); + pm.addNestedPass(arith::createArithExpandOpsPass()); // todo: lower to physical vector pass, device dependent pass populateCleanUpPasses(pm); } From eec389f5798116ba21e347aa0be262f58b3eff1f Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Mon, 9 Sep 2024 16:08:08 +0800 Subject: [PATCH 43/66] fix clang-format --- lib/gc/Transforms/CPUPhysicalRegisterPass.cpp | 157 ++++++++++++------ lib/gc/Transforms/TilingVector.h | 100 +++-------- 2 files changed, 126 insertions(+), 131 deletions(-) diff --git a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp index 9c770fc6e..03e151318 100644 --- a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp +++ b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp @@ -86,11 +86,8 @@ template , T>> static bool isOperationsHasDefUseRelation(Operation *op1, Operation *op2) { - for (Value opd : op2->getOperands()) - if (opd.getDefiningOp() == op1) - return true; - - return false; + return llvm::any_of(op2->getOperands(), + [&op1](Value opd) { return opd.getDefiningOp() == op1; }); } /// Get the index position of the first element that is true static size_t getFirstTrueIndex(ArrayRef ararys) { @@ -221,14 +218,18 @@ bool hasDynamicShape(Operation *op) { return false; }; // Check operands data type. - for (auto x : op->getOperands()) - if (isDynamicShapedType(x)) - return true; + if (llvm::any_of(op->getOperands(), [&isDynamicShapedType](Value x) { + return isDynamicShapedType(x); + })) { + return true; + } // Check results data type. - for (auto x : op->getResults()) - if (isDynamicShapedType(x)) - return true; + if (llvm::any_of(op->getResults(), [&isDynamicShapedType](OpResult x) { + return isDynamicShapedType(x); + })) { + return true; + } return false; } @@ -267,6 +268,77 @@ bool isLastDim(const AffineExpr &expr, const size_t rank) { dyn_cast(expr).getPosition() == rank - 1; } +void GenerateLoopHelper::setNextAnchorArgs( + DenseMap &nextAnchorArgsIdxMap, + SmallVector &nextAnchorArgs) { + currentLoopStateIdxMap = nextAnchorArgsIdxMap; + loopIterArgs = nextAnchorArgs; +} + +void GenerateLoopHelper::clearNextAnchorResults() { + nextAnchorResults.clear(); + nextAnchorResultsIdxMap.clear(); + nextAnchorResultOrignalResultMap.clear(); +} + +void GenerateLoopHelper::setAnchorId(size_t anchorId) noexcept { + anchorIdx = anchorId; +} + +void GenerateLoopHelper::updateDataBeforePreOpMove( + ArrayRef loopState, std::queue &candidateQueue, + std::queue &movedQueue) { + loopIterArgs = loopState; + candidateOps = &candidateQueue; + movedOps = &movedQueue; +} + +void GenerateLoopHelper::updateDataAfterPreOpMove( + DenseMap &nextAnchorArgsIdxMap, + SmallVector &nextAnchorArgs) { + setNextAnchorArgs(nextAnchorArgsIdxMap, nextAnchorArgs); +} + +void GenerateLoopHelper::updateDataBeforePostOpMove( + ArrayRef iterArgs, DenseMap ¤tLoopStateIdxMap, + DenseMap ¤toriginalArgsMap, + DenseMap ¤tArgsOriginalMap, ValueRange forResults, + Block *forBlock, std::queue &movedQueue, size_t anchorId) { + this->originalOperandLoopArgsMap = currentoriginalArgsMap; + this->loopArgsOriginalOperandMap = currentArgsOriginalMap; + this->forResults = forResults; + this->forBlock = forBlock; + this->anchorIdx = anchorId; + this->currentLoopStateIdxMap = currentLoopStateIdxMap; + this->loopIterArgs = iterArgs; + this->movedOps = &movedQueue; +} + +void 
GenerateLoopHelper::updateDataAfterPostOpMove( + size_t anchorId, DenseMap &nextAnchorArgsIdxMap, + SmallVector &nextAnchorArgs) { + setAnchorId(anchorId); + setNextAnchorArgs(nextAnchorArgsIdxMap, nextAnchorArgs); +} + +void GenerateLoopHelper::setNextAnchorResults( + SmallVector ¤tAnchorResults, + DenseMap ¤tResultMap, + DenseMap ¤tResultIdxMap) { + nextAnchorResults = std::move(currentAnchorResults); + nextAnchorResultOrignalResultMap = std::move(currentResultMap); + nextAnchorResultsIdxMap = std::move(currentResultIdxMap); +} + +void GenerateLoopHelper::updateCurrentArgsStatus( + DenseMap ¤tArgsIdxMap, SmallVector ¤tArgs, + DenseMap &originalArgsMap, + DenseMap &argsOriginalMap) { + setNextAnchorArgs(currentArgsIdxMap, currentArgs); + originalOperandLoopArgsMap = originalArgsMap; + loopArgsOriginalOperandMap = argsOriginalMap; +} + int TypeHelper::generateValidSteps(int steps, VectorType type) { if (type.getShape().back() >= steps) return steps; @@ -393,8 +465,8 @@ VectorType TypeHelper::getVectorzedType(Operation *op, uint32_t loop_step) { /// \param retType resuilt return type bool needReturnResult(std::pair &retType, size_t anchorIdx) { - return !(retType.first == ReturnTypeKind::RT_InGroup and - retType.second >= anchorIdx); + return retType.first != ReturnTypeKind::RT_InGroup or + retType.second < anchorIdx; } union Float32Bits { @@ -878,22 +950,6 @@ void getReductionInitAttr(vector::MultiDimReductionOp &multiReductionOp, getInitValForReduce(multiReductionOp.getKind(), vecType)); } -void classifySourceRelatedOps(std::queue &accRelatedOps, - std::queue &sourceRelatedOps, - Operation *srcOp, - std::queue &prevOps) { - DenseSet srcOps; - getOpSourceOps(srcOp, srcOps); - while (!prevOps.empty()) { - auto op = prevOps.front(); - prevOps.pop(); - if (isSrcRelated(srcOps, op) or op == srcOp) - sourceRelatedOps.push(op); - else - accRelatedOps.push(op); - } -} - /// get multi_reduction operation accumulate value source related operations /// \param srcOp accumulate value source operation void classifyAccRelatedOps(std::queue &accRelatedOps, @@ -980,7 +1036,7 @@ void ForLoopGenerator::moveOperationsToCurrentForBody( void ForLoopGenerator::getResultInCurrentOps( const size_t anchorIdx, const size_t groupId, - const std::queue ops, SmallVector &results, + const std::queue &ops, SmallVector &results, DenseMap &nextAnchorResultsIdxMap, DenseMap &forResultOrignalResultMap) { auto tmpQ(ops); @@ -1182,7 +1238,7 @@ void ForLoopGenerator::generateLoopResults( loopHelperParam.nextAnchorResults.clear(); loopHelperParam.nextAnchorResultsIdxMap.clear(); // reduction operation due to special process results size will be zero - if (results.size() > 0) + if (not results.empty()) for (Value x : loopHelperParam.loopIterArgs) { loopHelperParam.nextAnchorResults.emplace_back( results[nextOperandIdxMap[x]]); @@ -2031,7 +2087,7 @@ void ForLoopGenerator::rectifyReadOperationIndice( // currently only broadcast (fuse as transfer_read) will move into more inner // loop if (readTensorType.getRank() - 1 >= - getFusionStrategy().getOpAnchorPos()[*originalReadOp]) + (int64_t)getFusionStrategy().getOpAnchorPos()[*originalReadOp]) return; int64_t itrIdx = loopType.getRank() - 1; @@ -2062,7 +2118,7 @@ scf::ForOp ForLoopGenerator::generateShapeCastForLoop(const size_t grpIdx) { SmallVector successorWriteOps; for (Operation *x : scOp->getUsers()) if (isa(x) and opIndexMap.contains(x) and - opIndexMap[x] == opIndexMap[x]) + opIndexMap[x] == opIndexMap[scOp]) successorWriteOps.emplace_back(cast(x)); for (auto successorWriteOp 
: successorWriteOps) @@ -2108,7 +2164,7 @@ void ForLoopGenerator::getCurrentGroupIndiceLoopMap( DenseMap forIdxMap; VectorType groupVector = getFusionStrategy().getGroupBiggestRankVectorType()[groupId]; - for (size_t i = 0; i < groupVector.getRank(); i++) { + for (size_t i = 0; (int64_t)i < groupVector.getRank(); i++) { forIdxMap[i] = i; } indiceLoopMap[op] = forIdxMap; @@ -2245,8 +2301,6 @@ void MultiReductionCanonicalizer::prepareSpecialOperationInfo() { } }; -void TransposeCanonicalizer::prepareSpecialOperationInfo() {} - bool TransposeCanonicalizer::isTransposeOnAllOneDim() { vector::TransposeOp tpOp = getCandidateOps()[0]; ArrayRef permutation = tpOp.getPermutation(); @@ -2311,7 +2365,7 @@ bool TransposeCanonicalizer::isTwoDTranspose() { bool TransposeCanonicalizer::transposeOnLastDim() { ArrayRef permutation = getCandidateOps()[0].getPermutation(); size_t rank = permutation.size(); - if (permutation[rank - 1] != rank - 1) + if (permutation[rank - 1] != (int64_t)rank - 1) return false; VectorType vtType = getCandidateOps()[0].getResultVectorType(); @@ -2334,12 +2388,12 @@ bool ShapeCastCanonicalizer::isReadWriteOnLastDim() { // Map the index of the larger rank shape to the index of the smaller rank // shape. DenseMap> shapeIdxMap; - for (size_t i = 0; i < smallRankType.getRank(); i++) - shapeIdxMap[i] = std::move(SmallVector()); + for (size_t i = 0; (int64_t)i < smallRankType.getRank(); i++) + shapeIdxMap[i] = SmallVector(); - size_t itrIdx = 0; + int64_t itrIdx = 0; while (itrIdx < smallRankType.getRank()) { - size_t endShape = getFirstTrueIndex(visitedAxis), dimSize = 1; + int64_t endShape = getFirstTrueIndex(visitedAxis), dimSize = 1; assert(endShape < largeRankType.getRank() and endShape >= 0 && "Invalid endShape"); // skip non corresponding axis @@ -2423,7 +2477,7 @@ void CanonicalizerVectorOperation::initSpeicalOperationCanonicalizers() { template void CanonicalizerVectorOperation::processSpecialOperation( - T &canonicalizers, std::function generateFunc) { + T &canonicalizers, const std::function &generateFunc) { for (auto [groupId, canonicalizer] : llvm::enumerate(canonicalizers)) { SmallVector &ops = canonicalizer.getCandidateOps(); if (!ops.empty()) @@ -2585,7 +2639,7 @@ void ForLoopGenerator::setOperationCorrectOperand( llvm::llvm_unreachable_internal( "Permuatation map must contains dim expr."); - size_t dim; + size_t dim = 0; if (auto d = dyn_cast(x)) { dim = d.getPosition(); } else if (auto d = dyn_cast(x)) { @@ -2594,7 +2648,7 @@ void ForLoopGenerator::setOperationCorrectOperand( ShapedType tensorType = cast(op->getOperandTypes()[offset - 1]); - size_t varIdx = dim; + int64_t varIdx = dim; if (tensorType.getRank() > (int64_t)inductionVars.size()) { int64_t tensorOffset = tensorType.getRank() - inductionVars.size(); if (dim < tensorOffset) @@ -2877,11 +2931,8 @@ void getOperationDataAxis(Operation *op, SmallVector &dataAxis) { static inline bool hasSameAxis(ArrayRef dims1, ArrayRef dims2) { DenseSet checkSet(dims2.begin(), dims2.end()); - for (auto x : dims1) - if (checkSet.contains(x)) - return true; - - return false; + return llvm::any_of(dims1, + [&checkSet](int64_t x) { return checkSet.contains(x); }); } /// whether two operation has data dependency @@ -3031,16 +3082,16 @@ bool VectorFusionStrategy::isNeedNewGroup(Operation *op) { return false; } -void VectorFusionStrategy::updateGroupBitgestVectorType(VectorType vectorType) { +void VectorFusionStrategy::updateGroupBigestVectorType(VectorType vectorType) { int64_t rank = vectorType.getRank(); llvm::SmallDenseMap 
&groupVectorType = getGroupBiggestRankVectorType(); if (groupVectorType.contains(opGroups.size() - 1)) { VectorType bigestType = groupVectorType[opGroups.size() - 1]; - if (bigestType.getRank() < rank) { + if (bigestType.getRank() < rank) groupVectorType[opGroups.size() - 1] = vectorType; - } + return; } @@ -3054,7 +3105,7 @@ void VectorFusionStrategy::addOperationToGroup(Operation *op) { opGroups.emplace_back(std::queue()); if (not isa(op)) { - updateGroupBitgestVectorType(vectorType); + updateGroupBigestVectorType(vectorType); while (not noNeedToJudgeOps.empty()) { auto cur = noNeedToJudgeOps.front(); noNeedToJudgeOps.pop(); diff --git a/lib/gc/Transforms/TilingVector.h b/lib/gc/Transforms/TilingVector.h index 3cb8d26fb..eefdc4c3d 100644 --- a/lib/gc/Transforms/TilingVector.h +++ b/lib/gc/Transforms/TilingVector.h @@ -53,7 +53,7 @@ struct HardWareInfo { bool favx2 = true; }; -/// Using to avoid too many parameters in function +/// To avoid too many parameters in function when generate for loop struct GenerateLoopHelper { /// anchor id size_t anchorIdx = 0; @@ -125,83 +125,13 @@ struct GenerateLoopHelper { DenseMap &nextAnchorArgsIdxMap, SmallVector &nextAnchorArgs); + /// update loop iteration args data void updateCurrentArgsStatus(DenseMap ¤tArgsIdxMap, SmallVector ¤tArgs, DenseMap &originalArgsMap, DenseMap &argsOriginalMap); }; -void GenerateLoopHelper::setNextAnchorArgs( - DenseMap &nextAnchorArgsIdxMap, - SmallVector &nextAnchorArgs) { - currentLoopStateIdxMap = nextAnchorArgsIdxMap; - loopIterArgs = nextAnchorArgs; -} - -void GenerateLoopHelper::clearNextAnchorResults() { - nextAnchorResults.clear(); - nextAnchorResultsIdxMap.clear(); - nextAnchorResultOrignalResultMap.clear(); -} - -void GenerateLoopHelper::setAnchorId(size_t anchorId) noexcept { - anchorIdx = anchorId; -} - -void GenerateLoopHelper::updateDataBeforePreOpMove( - ArrayRef loopState, std::queue &candidateQueue, - std::queue &movedQueue) { - loopIterArgs = loopState; - candidateOps = &candidateQueue; - movedOps = &movedQueue; -} - -void GenerateLoopHelper::updateDataAfterPreOpMove( - DenseMap &nextAnchorArgsIdxMap, - SmallVector &nextAnchorArgs) { - setNextAnchorArgs(nextAnchorArgsIdxMap, nextAnchorArgs); -} - -void GenerateLoopHelper::updateDataBeforePostOpMove( - ArrayRef iterArgs, DenseMap ¤tLoopStateIdxMap, - DenseMap ¤toriginalArgsMap, - DenseMap ¤tArgsOriginalMap, ValueRange forResults, - Block *forBlock, std::queue &movedQueue, size_t anchorId) { - this->originalOperandLoopArgsMap = currentoriginalArgsMap; - this->loopArgsOriginalOperandMap = currentArgsOriginalMap; - this->forResults = forResults; - this->forBlock = forBlock; - this->anchorIdx = anchorId; - this->currentLoopStateIdxMap = currentLoopStateIdxMap; - this->loopIterArgs = iterArgs; - this->movedOps = &movedQueue; -} - -void GenerateLoopHelper::updateDataAfterPostOpMove( - size_t anchorId, DenseMap &nextAnchorArgsIdxMap, - SmallVector &nextAnchorArgs) { - setAnchorId(anchorId); - setNextAnchorArgs(nextAnchorArgsIdxMap, nextAnchorArgs); -} - -void GenerateLoopHelper::setNextAnchorResults( - SmallVector ¤tAnchorResults, - DenseMap ¤tResultMap, - DenseMap ¤tResultIdxMap) { - nextAnchorResults = std::move(currentAnchorResults); - nextAnchorResultOrignalResultMap = std::move(currentResultMap); - nextAnchorResultsIdxMap = std::move(currentResultIdxMap); -} - -void GenerateLoopHelper::updateCurrentArgsStatus( - DenseMap ¤tArgsIdxMap, SmallVector ¤tArgs, - DenseMap &originalArgsMap, - DenseMap &argsOriginalMap) { - setNextAnchorArgs(currentArgsIdxMap, 
currentArgs); - originalOperandLoopArgsMap = originalArgsMap; - loopArgsOriginalOperandMap = argsOriginalMap; -} - /// set correct operand for the operation void setOperationCorrectOperand( Operation *op, ValueRange iterArgs, DenseMap &operandIdxMap, @@ -209,6 +139,10 @@ void setOperationCorrectOperand( ArrayRef inductionVars, DenseMap &opPermuationMap); +//===----------------------------------------------------------------------===// +// vectorize operation class +//===----------------------------------------------------------------------===// + /// Vector type conversion helper class class TypeHelper { private: @@ -294,10 +228,11 @@ class VectorFusionStrategy : public TypeHelper { llvm::SmallVector &getGroupMaxSteps() noexcept { return groupMaxSteps; } + /// Get the map contains anchor position of each operation llvm::DenseMap &getOpAnchorPos() noexcept { return opAnchorPos; } - + /// Get current function IR func::FuncOp &getFunc() { return func; } /// Do fusion strategy void classifyOperations(); @@ -305,7 +240,8 @@ class VectorFusionStrategy : public TypeHelper { /// Whether two operations have compatible vector shapes bool isCompatibleVectorType(Operation *op1, Operation *op2); - void updateGroupBitgestVectorType(VectorType vectorType); + /// update bigest vector type for last operation group + void updateGroupBigestVectorType(VectorType vectorType); /// Check whether the operation can fuse with previous operation bool isNeedNewGroup(Operation *op); @@ -360,18 +296,26 @@ template class SpecialOperationCanonicalizer : virtual TypeHelper { }; enum class MultiReduceOpAxisKind { Reduction, Parallel }; +/// Help to vectorize reduction operation class MultiReductionCanonicalizer : public SpecialOperationCanonicalizer { private: + /// reduction parallel axis and reduction axis SmallVector reductionAxis, parallelAxis; + /// operations before reduction operation and operations after reduction + /// operation std::queue prevOps, postOps, accRelatedOps, sourceRelatedOps; bool haslastDimReduction = false; bool isStandaloneOp = false; /// empty reduction means that all the reduction axis is 1 bool isEmptyReduction = true; + /// vector type rank int64_t typeRank = -1; + /// record original operation result SetVector originalOpResults; + /// vector type of source operation and accumulate operation VectorType sourceType, accType; + /// for loop yield result index map llvm::SmallDenseMap resultIdxMap; public: @@ -450,7 +394,7 @@ class TransposeCanonicalizer : SpecialOperationCanonicalizer( candidateTpOps, SpecialOperationKind::OP_Transpose, steps) {}; virtual ~TransposeCanonicalizer() noexcept {} - void prepareSpecialOperationInfo() override; + void prepareSpecialOperationInfo() override {}; static bool classof(SpecialOperationCanonicalizer *canonicalizer) { return canonicalizer->getKind() == SpecialOperationKind::OP_Transpose; } @@ -682,7 +626,7 @@ class ForLoopGenerator : virtual public CanonicalizerCommonUsedData { DenseMap> &indiceLoopMap); void getResultInCurrentOps(const size_t anchorIdx, const size_t groupId, - const std::queue ops, + const std::queue &ops, SmallVector &results, DenseMap &nextAnchorResultsIdxMap, DenseMap &forResultOrignalResultMap); @@ -843,8 +787,8 @@ class CanonicalizerVectorOperation : virtual public ForLoopGenerator, func::FuncOp &getFunc() noexcept { return func; }; IRRewriter &getIRWewriter() noexcept { return rewriter; } template - void processSpecialOperation(T &canonicalizers, - std::function generateFunc); + void processSpecialOperation( + T &canonicalizers, const 
std::function &generateFunc); // void canonicalizeSpecialOperation(); void clearSpecialOperationCanonicalizers(); From cdbe4e298a80d29ce7ff1359c4697bdf4b36307b Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Mon, 9 Sep 2024 16:24:18 +0800 Subject: [PATCH 44/66] fix format --- lib/gc/Transforms/CPUPhysicalRegisterPass.cpp | 18 ++++++++---------- lib/gc/Transforms/TilingVector.h | 5 ++--- 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp index 03e151318..e540f5f25 100644 --- a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp +++ b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp @@ -444,7 +444,7 @@ mlir::FailureOr getOperationMaxVectorType(Operation *op) { return ret; } -VectorType TypeHelper::getVectorzedType(Operation *op, uint32_t loop_step) { +VectorType TypeHelper::getVectorzedType(Operation *op, uint32_t loopStep) { // Check that the operation type can be broken // down into a loop. mlir::FailureOr baseType = getOperationVectorType(op); @@ -454,10 +454,10 @@ VectorType TypeHelper::getVectorzedType(Operation *op, uint32_t loop_step) { return VectorType(); } auto vectorizedType = baseType.value(); - if (loop_step == 0) - loop_step = getDataTypeValidSteps(vectorizedType); + if (loopStep == 0) + loopStep = getDataTypeValidSteps(vectorizedType); - return VectorType::get({loop_step}, vectorizedType.getElementType()); + return VectorType::get({loopStep}, vectorizedType.getElementType()); } /// whether the operation result need to be returned @@ -2252,11 +2252,10 @@ void MultiReductionCanonicalizer::initReductionAxis() { void MultiReductionCanonicalizer::initParallelAxis() { llvm::SmallDenseSet reductionAxisSet(reductionAxis.begin(), reductionAxis.end()); - for (int64_t i = 0; i < typeRank; ++i) { - if (!reductionAxisSet.contains(i)) { + for (int64_t i = 0; i < typeRank; ++i) + if (!reductionAxisSet.contains(i)) parallelAxis.push_back(i); - } - } + llvm::sort(parallelAxis); } @@ -2369,7 +2368,6 @@ bool TransposeCanonicalizer::transposeOnLastDim() { return false; VectorType vtType = getCandidateOps()[0].getResultVectorType(); - if (vtType.getShape()[rank - 1] % getVectorStep() != 0) return false; @@ -2639,7 +2637,7 @@ void ForLoopGenerator::setOperationCorrectOperand( llvm::llvm_unreachable_internal( "Permuatation map must contains dim expr."); - size_t dim = 0; + int64_t dim = 0; if (auto d = dyn_cast(x)) { dim = d.getPosition(); } else if (auto d = dyn_cast(x)) { diff --git a/lib/gc/Transforms/TilingVector.h b/lib/gc/Transforms/TilingVector.h index eefdc4c3d..388de6cce 100644 --- a/lib/gc/Transforms/TilingVector.h +++ b/lib/gc/Transforms/TilingVector.h @@ -325,7 +325,6 @@ class MultiReductionCanonicalizer : SpecialOperationCanonicalizer( candidateRdOps, SpecialOperationKind::OP_MultiDimReduction, steps) { isStandaloneOp = candidateRdOps.size() == 1; - prepareSpecialOperationInfo(); }; virtual ~MultiReductionCanonicalizer() noexcept {}; int64_t getTypeRank(); @@ -611,7 +610,7 @@ class ForLoopGenerator : virtual public CanonicalizerCommonUsedData { scf::ForOp constructNestedForOp(const size_t groupIdx, OpBuilder &b, const Location &loc, ArrayRef dims, - GenerateLoopHelper &loopGenerator); + GenerateLoopHelper &loopHelper); void moveOperationsToCurrentForBody(const OpBuilder &b, std::queue &queue, @@ -683,7 +682,7 @@ class ForLoopGenerator : virtual public CanonicalizerCommonUsedData { MultiReduceOpAxisKind rdKind = MultiReduceOpAxisKind::Parallel); /// generate for loop for transpose 
operation - scf::ForOp generateTransposeForLoop(const size_t groupId); + scf::ForOp generateTransposeForLoop(const size_t grpIdx); scf::ForOp generateTransposeForLoopWithLastDim( OpBuilder &opBuilder, const int tpSteps, const Location &loc, Operation *successorWriteOp, GenerateLoopHelper &loopHelperParam); From 5bf4a9f2bd9df9e1ed15cf7c5a28de0814ca08ee Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Mon, 9 Sep 2024 16:29:22 +0800 Subject: [PATCH 45/66] enable mincrokernel op in vector --- lib/gc/Transforms/CPUPhysicalRegisterPass.cpp | 8 ++------ lib/gc/Transforms/TilingVector.h | 2 +- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp index e540f5f25..935047c14 100644 --- a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp +++ b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp @@ -31,8 +31,7 @@ namespace { linalg::MatmulTransposeAOp, linalg::MatmulTransposeBOp, \ linalg::QuantizedBatchMatmulOp, linalg::QuantizedMatmulOp, \ tensor::CollapseShapeOp, tensor::ExpandShapeOp, tensor::ExtractSliceOp, \ - tensor::InsertSliceOp -// , microkernel::BrgemmOp + tensor::InsertSliceOp, microkernel::BrgemmOp /// TODO: remove it in the future bool disableSpecialOp = false; @@ -2368,10 +2367,7 @@ bool TransposeCanonicalizer::transposeOnLastDim() { return false; VectorType vtType = getCandidateOps()[0].getResultVectorType(); - if (vtType.getShape()[rank - 1] % getVectorStep() != 0) - return false; - - return true; + return vtType.getShape()[rank - 1] % getVectorStep() != 0; } bool ShapeCastCanonicalizer::isReadWriteOnLastDim() { diff --git a/lib/gc/Transforms/TilingVector.h b/lib/gc/Transforms/TilingVector.h index 388de6cce..e2ba8bd5c 100644 --- a/lib/gc/Transforms/TilingVector.h +++ b/lib/gc/Transforms/TilingVector.h @@ -10,6 +10,7 @@ #include "gc/Analysis/TargetDescriptionAnalysis.h" #include "gc/Dialect/Linalgx/LinalgxOps.h" +#include "gc/Dialect/Microkernel/MicrokernelOps.h" #include "gc/Transforms/Passes.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -32,7 +33,6 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" #include -// #include "gc/Dialect/Microkernel/MicrokernelOps.h" namespace mlir { namespace gc { namespace { From b0a26cd72364343eaeb6a0894544cf5eb8ca3da3 Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Mon, 9 Sep 2024 16:47:12 +0800 Subject: [PATCH 46/66] fix comments --- lib/gc/Transforms/CPUPhysicalRegisterPass.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp index 935047c14..0633d80e0 100644 --- a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp +++ b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp @@ -2367,7 +2367,7 @@ bool TransposeCanonicalizer::transposeOnLastDim() { return false; VectorType vtType = getCandidateOps()[0].getResultVectorType(); - return vtType.getShape()[rank - 1] % getVectorStep() != 0; + return vtType.getShape()[rank - 1] % getVectorStep() == 0; } bool ShapeCastCanonicalizer::isReadWriteOnLastDim() { @@ -2763,9 +2763,9 @@ bool VectorFusionStrategy::isCompatibleVectorType(Operation *op1, // some operation has two different operands type like multireduction, we need // to check whether compitable with accumulate vector VectorType suppleType; - if (failed(type1) || failed(type2)) { + if (failed(type1) || failed(type2)) return false; - } + auto 
sp1 = type1.value(); auto sp2 = type2.value(); From dfa5ea335b4fca5f08e4a092f96a3c9d9ad66339 Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Tue, 10 Sep 2024 09:01:05 +0800 Subject: [PATCH 47/66] add comments --- lib/gc/Transforms/CPUPhysicalRegisterPass.cpp | 96 +++++---------- lib/gc/Transforms/TilingVector.h | 113 +++++++++++++----- .../gc/Transforms/cpu-vetor-distribution.mlir | 10 +- 3 files changed, 119 insertions(+), 100 deletions(-) diff --git a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp index 0633d80e0..46e0c8e10 100644 --- a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp +++ b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp @@ -979,35 +979,6 @@ void updateReduceReadWriteOperationOperand( } } -vector::TransferReadOp ForLoopGenerator::cloneReductionTransferRead( - Value &source, OpBuilder &b, IRMapping &readMap, - const SmallVector ¶llelAxis, - SmallVector &inductionVars, bool lastDimReduction, - MultiReduceOpAxisKind rdKind) { - IRRewriter rewriter(b); - auto readOp = dyn_cast(source.getDefiningOp()); - assert(readOp && " Not transfer_read operation. Current multireduction " - "operation may have wrong analysis IR."); - - Operation *clonedOp = b.clone(*readOp, readMap); - auto newReadOp = cast(clonedOp); - updateReduceReadWriteOperationOperand(inductionVars, parallelAxis, newReadOp, - rdKind); - - // modify the type of the new read operation - auto newOperandType = - (lastDimReduction && rdKind == MultiReduceOpAxisKind::Reduction) - ? getVectorzedType(newReadOp) - : getScalarType(newReadOp); - newReadOp->getResult(0).setType(newOperandType); - setOpVectorizationPermutationMap( - newReadOp, b, cast(newReadOp.getSource().getType()), - newReadOp.getPermutationMap()); - - rewriter.replaceOp(readOp, newReadOp); - return newReadOp; -} - Value makeIndexArithConstantOp(OpBuilder &opBuilder, const Location &loc, int64_t x) { return opBuilder.create( @@ -1025,11 +996,7 @@ void ForLoopGenerator::moveOperationsToCurrentForBody( tmpQ.pop(); x->moveBefore(b.getBlock(), b.getBlock()->end()); // check operation type to set correct operand - setOperationCorrectOperand(x, loopHelperParam.loopIterArgs, - loopHelperParam.currentLoopStateIdxMap, - loopHelperParam.originalOperandLoopArgsMap, - loopHelperParam.inductionVars, opPermuationMap, - loopHelperParam.indiceLoopMap); + setOperationCorrectOperand(x, opPermuationMap, loopHelperParam); } } @@ -1923,11 +1890,10 @@ scf::ForOp ForLoopGenerator::generateTransposeScalarDataMovement( }); } -/// generate simple data movement for loop scf::ForOp ForLoopGenerator::generateShapeCastReadWriteLoop( OpBuilder &opBuilder, const size_t grpIdx, const size_t forDimIdx, const size_t steps, const Location &loc, SmallVector &inductionVars, - const ValueRange &iterArgs) { + ValueRange iterArgs) { auto &scCanonicalizer = getShapeCastCanonicalizers()[grpIdx]; vector::ShapeCastOp &scOp = scCanonicalizer.getCandidateOps()[0]; VectorType sourceType = scOp.getSourceVectorType(); @@ -2562,9 +2528,9 @@ void CanonicalizerVectorOperation::run() { canonicalizeSpecialOperation(); // 2.Generate vectorized IR for each operation group - for (size_t idx = 0; idx < fusionStrategy.getOpGroups().size(); ++idx) { + for (size_t idx = 0; idx < fusionStrategy.getOpGroups().size(); ++idx) generateGroupOpVectorizedIR(idx); - } + // 3. 
Some IR cleanup work DominanceInfo domInfo; eliminateCommonSubExpressions(rewriter, domInfo, func); @@ -2604,21 +2570,20 @@ void CanonicalizerVectorOperation::run() { /// void ForLoopGenerator::setOperationCorrectOperand( - Operation *op, ValueRange iterArgs, - const DenseMap &operandIdxMap, - DenseMap &originalOperandLoopArgsMap, - ArrayRef inductionVars, - const DenseMap &opPermuationMap, - DenseMap> &indiceloopMap) { + Operation *op, const DenseMap &opPermuationMap, + GenerateLoopHelper &loopHelperParam) { for (auto [idx, opd] : llvm::enumerate(op->getOperands())) { - if (not originalOperandLoopArgsMap.contains(opd)) + if (not loopHelperParam.originalOperandLoopArgsMap.contains(opd)) continue; - Value loopArg = originalOperandLoopArgsMap[opd]; - if (not operandIdxMap.contains(loopArg)) + Value loopArg = loopHelperParam.originalOperandLoopArgsMap[opd]; + if (not loopHelperParam.currentLoopStateIdxMap.contains(loopArg)) continue; - op->setOperand(idx, iterArgs[operandIdxMap.at(loopArg)]); + op->setOperand( + idx, + loopHelperParam + .loopIterArgs[loopHelperParam.currentLoopStateIdxMap.at(loopArg)]); } int offset = isa(op) ? 2 : 1; if (dyn_cast(op) || @@ -2643,17 +2608,22 @@ void ForLoopGenerator::setOperationCorrectOperand( ShapedType tensorType = cast(op->getOperandTypes()[offset - 1]); int64_t varIdx = dim; - if (tensorType.getRank() > (int64_t)inductionVars.size()) { - int64_t tensorOffset = tensorType.getRank() - inductionVars.size(); + if (tensorType.getRank() > + (int64_t)loopHelperParam.inductionVars.size()) { + int64_t tensorOffset = + tensorType.getRank() - loopHelperParam.inductionVars.size(); if (dim < tensorOffset) continue; varIdx = dim - tensorOffset; } - if (indiceloopMap.contains(op)) - op->setOperand(dim + offset, inductionVars[indiceloopMap[op][varIdx]]); + if (loopHelperParam.indiceLoopMap.contains(op)) + op->setOperand( + dim + offset, + loopHelperParam + .inductionVars[loopHelperParam.indiceLoopMap[op][varIdx]]); else - op->setOperand(dim + offset, inductionVars[varIdx]); + op->setOperand(dim + offset, loopHelperParam.inductionVars[varIdx]); } if (auto readOp = dyn_cast(op)) { size_t grpIdx = getFusionStrategy().getOpGroupIndexMap()[op]; @@ -2661,7 +2631,8 @@ void ForLoopGenerator::setOperationCorrectOperand( getFusionStrategy().getGroupBiggestRankVectorType()[grpIdx]; SmallVector readIndices(readOp.getIndices().begin(), readOp.getIndices().end()); - rectifyReadOperationIndice(&readOp, loopType, inductionVars, readIndices); + rectifyReadOperationIndice(&readOp, loopType, + loopHelperParam.inductionVars, readIndices); readOp.getIndicesMutable().assign(readIndices); } } @@ -3040,7 +3011,7 @@ Operation *getNotReadWriteOperaiton(std::queue &tmpQ) { bool VectorFusionStrategy::isNeedNewGroup(Operation *op) { if (isa(op)) { - noNeedToJudgeOps.push(op); + notNeedToJudgeOps.push(op); return false; } // 1. 
check previous operation @@ -3051,13 +3022,12 @@ bool VectorFusionStrategy::isNeedNewGroup(Operation *op) { prevOp = getNotReadWriteOperaiton(tmpQ); if (!prevOp) { - if (opGroups.back().back()->getParentOp() != op->getParentOp()) { + if (opGroups.back().back()->getParentOp() != op->getParentOp() or + isSpecialOp(op)) { // if previous operation is not in the same block, we need to create a // group return true; } - if (isSpecialOp(op)) - return true; return false; } @@ -3100,9 +3070,9 @@ void VectorFusionStrategy::addOperationToGroup(Operation *op) { if (not isa(op)) { updateGroupBigestVectorType(vectorType); - while (not noNeedToJudgeOps.empty()) { - auto cur = noNeedToJudgeOps.front(); - noNeedToJudgeOps.pop(); + while (not notNeedToJudgeOps.empty()) { + auto cur = notNeedToJudgeOps.front(); + notNeedToJudgeOps.pop(); opGroupIndexMap[cur] = opGroups.size() - 1; opGroups.back().push(cur); } @@ -3194,7 +3164,6 @@ void setOperationOperandResult(Operation *op, const VectorType &newOperandType, x.setType(newOperandType); }; -/// Reimplementation of writing a tensor from a constant of denseElementattr. void ForLoopGenerator::createNewConstantOp( Operation *srcOp, vector::TransferWriteOp *transferWriteOp, size_t groupSteps) { @@ -3360,7 +3329,6 @@ void CanonicalizerCommonUsedData::removeOpInCurrentGroups( getFusionStrategy().getOpAnchorPos()[x] = getOperationMaxVectorType(x)->getRank() - 1; - // update operation in grpIdx group related information updateOpGroupInfo(grpIdx); } @@ -3935,7 +3903,7 @@ void moveSomeInterferenceOperation( // Pre-order traversal of each op // Record each operation position. Inorder to we can kown current operation // should move after which operation. - llvm::DenseMap operationPosition; + DenseMap operationPosition; SmallVector candidateOps; size_t opCounter = 0; diff --git a/lib/gc/Transforms/TilingVector.h b/lib/gc/Transforms/TilingVector.h index e2ba8bd5c..9ad3d074a 100644 --- a/lib/gc/Transforms/TilingVector.h +++ b/lib/gc/Transforms/TilingVector.h @@ -189,7 +189,7 @@ class VectorFusionStrategy : public TypeHelper { DenseMap opAnchorPos; /// record some operations which not need to No need to judge whether can be /// fused - std::queue noNeedToJudgeOps; + std::queue notNeedToJudgeOps; public: VectorFusionStrategy() = default; @@ -263,9 +263,12 @@ class VectorFusionStrategy : public TypeHelper { /// and we directly convert the operations into physical register sizes. 
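/// For illustration (assuming an avx512-like target whose physical register
/// holds 16 x f32): a group whose special operation is a `vector.transpose`
/// with permutation [1, 0] on a 16x16xf32 vector takes the
/// TransposeCanonicalizer 16x16 kernel path, which emits roughly
///   scf.for %i = %c0 to %c16 step %c16 {
///     scf.for %j = %c0 to %c16 step %c16 {
///       %tile = vector.transfer_read ... : tensor<16x16xf32>, vector<16x16xf32>
///       %tp = vector.transpose %tile, [1, 0]
///           : vector<16x16xf32> to vector<16x16xf32>
///       vector.transfer_write %tp, ... // the two transposed indices swapped
///     }
///   }
/// whereas a group containing no special operation is lowered to a plain loop
/// nest that processes one physical register (e.g. vector<16xf32>) per
/// iteration of the innermost loop.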
enum CanonicalizerKind { OperationsGroup, Operations }; +/// base class of special operation template class SpecialOperationCanonicalizer : virtual TypeHelper { private: + /// store current special operation SmallVector candidateRdOps; + /// vectorize step size_t vectorStep = 1; public: @@ -290,8 +293,11 @@ template class SpecialOperationCanonicalizer : virtual TypeHelper { llvm::SmallVector &getCandidateOps(); virtual ~SpecialOperationCanonicalizer() {} virtual void prepareSpecialOperationInfo() = 0; + /// get kind of speical operation SpecialOperationKind getKind() noexcept { return kind; } + /// set current operation group vectorize step void setVectorStep(size_t step) noexcept { vectorStep = step; } + /// get current operation group vectorize step size_t getVectorStep() noexcept { return vectorStep; } }; @@ -327,36 +333,56 @@ class MultiReductionCanonicalizer isStandaloneOp = candidateRdOps.size() == 1; }; virtual ~MultiReductionCanonicalizer() noexcept {}; + /// get reduction vector type, we use source operation type as reduction + /// vector type int64_t getTypeRank(); + /// get reduction operation reduction and parallel axis void getReductionAxisAndParallelAxis(); + /// whether last dim is reduction axis bool hasLastDimReduction(); + /// whether only reduction operation in current operation group bool getIsStandAloneOp() noexcept { return isStandaloneOp; } + /// get whether last dim is reduction axis bool getHasLastDimReduction() noexcept { return haslastDimReduction; } - bool getIsEmptyReduction() noexcept { return isEmptyReduction; } + /// initialize to get reduction axis void initReductionAxis(); + /// initialize to get parallel axis void initParallelAxis(); + /// get reduction axis SmallVector &getReductionAxis() noexcept { return reductionAxis; }; + /// get parallel axis SmallVector &getParallelAxis() noexcept { return parallelAxis; }; + /// get prev operation in current operation group std::queue &getPrevOps() noexcept { return prevOps; } + /// get post operation in current operation group std::queue &getPostOps() noexcept { return postOps; } + /// get accumulate operation in reduction operation std::queue &getAccRelatedOps() noexcept { return accRelatedOps; } + /// get source operation in reduction operation std::queue &getSourceRelatedOps() noexcept { return sourceRelatedOps; } + /// get reduction operation original result SetVector &getOriginalOpResults() noexcept { return originalOpResults; } + /// get source operation vector type VectorType getSourceType() noexcept { return sourceType; }; + /// get accumulate operation vector type VectorType getAccType() noexcept { return accType; }; + /// get result index map llvm::SmallDenseMap &getResultIdxMap() noexcept { return resultIdxMap; } + /// set result index map void setResultIdxMap(const llvm::SmallDenseMap &map) { resultIdxMap = map; } + /// initalize parallel, reduction axis, reduction operation type and whether + /// last dim is reduction axis void prepareSpecialOperationInfo() override; static bool classof(SpecialOperationCanonicalizer *canonicalizer) { @@ -384,6 +410,7 @@ class BroadcastCanonicalizer class TransposeCanonicalizer : public SpecialOperationCanonicalizer { private: + /// first and second transpose axis size_t firstTpIdx = 0, secondTpIdx = 0; public: @@ -400,11 +427,15 @@ class TransposeCanonicalizer enum TRANSPOSE_KERNEL { KERNEL_16X16 = 16, }; - + /// get first transpose axis size_t getFirstTpIdx() noexcept { return firstTpIdx; } + /// get second transpose axis size_t getSecondTpIdx() noexcept { return 
secondTpIdx; } + /// whether transpose on two dimensions bool isTwoDTranspose(); + /// whether transpose on all dimension size is one bool isTransposeOnAllOneDim(); + /// whether transpose on last dimension bool transposeOnLastDim(); }; @@ -422,6 +453,7 @@ class ShapeCastCanonicalizer static bool classof(SpecialOperationCanonicalizer *canonicalizer) { return canonicalizer->getKind() == SpecialOperationKind::OP_ShapeCast; } + /// whether store and load on last dimension bool isReadWriteOnLastDim(); }; @@ -441,6 +473,7 @@ class CanonicalizerCommonUsedData : public TypeHelper { /// analysis the operation's operands and results SmallVector>, 8> groupOpResults; + /// store loop iteration args for each of operation group SmallVector, 8> groupOpInitArgs; // store read and write operations permutation maps in order to convenient @@ -491,12 +524,12 @@ class CanonicalizerCommonUsedData : public TypeHelper { groupOpResults = std::move(results); } - void - setGroupOpIterArgs(const SmallVector, 8> &initArgs) { + void setGroupOpIterArgs( + const SmallVector, 8> &initArgs) noexcept { groupOpInitArgs = std::move(initArgs); } - void setPermutationMap(const DenseMap &map) { + void setPermutationMap(const DenseMap &map) noexcept { opPermuationMap = std::move(map); } @@ -536,9 +569,10 @@ class CanonicalizerCommonUsedData : public TypeHelper { return shapeCastCanonicalizers; } - // other methods + /// whether \param grpIdx operation group has special operation bool isGroupHasSpecialOperation(const size_t grpIdx); + /// make emtpy tensor and write the operation result to the tensor void generateEmptyTensorAndWrite( Operation *sourceOp, llvm::DenseMap> @@ -546,18 +580,23 @@ class CanonicalizerCommonUsedData : public TypeHelper { size_t anchorPos, ReturnTypeKind retKind, DenseMap &visitedOperation); + /// update \param opGid operation group void updateOpOperandResultInGroups(size_t opGid, Operation *op, const Value &init = Value(), const Value &result = Value()); + /// replace \param op in \param grpIdx operation group with \param replacedOp void removeOpInCurrentGroups(size_t grpIdx, Operation *op, Operation *replacedOp); + /// update operation in grpIdx group related information void updateOpGroupInfo(size_t grpIdx); + /// make a transfer_read operation and read the producer operation result Value canonicalizeCurrentOperation(Operation *op, const Value &transferReadOperand, size_t operandIdx, vector::TransferReadOp *srcReadOp = nullptr); + /// make a transfer_read operation Operation * createTransferReadOpBefore(Operation *op, const Value &operand, vector::TransferReadOp *srcReadOp = nullptr); @@ -570,6 +609,7 @@ class CanonicalizerCommonUsedData : public TypeHelper { /// generate for loop for each operation. 
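Each special-operation canonicalizer above exposes getKind() in the base class plus a static classof() in the derived class, which is the hook llvm::isa and llvm::dyn_cast use instead of dynamic_cast. A self-contained sketch of that idiom follows; the kind names and class names are illustrative, not the ones declared in this header.

#include <cassert>
#include <memory>

enum class SpecialKind { MultiReduction, Transpose, ShapeCast };

struct CanonicalizerBase {
  explicit CanonicalizerBase(SpecialKind k) : kind(k) {}
  virtual ~CanonicalizerBase() = default;
  SpecialKind getKind() const noexcept { return kind; }
private:
  SpecialKind kind;
};

struct TransposeSketch : CanonicalizerBase {
  TransposeSketch() : CanonicalizerBase(SpecialKind::Transpose) {}
  // llvm::isa/dyn_cast consult this hook instead of using dynamic_cast.
  static bool classof(const CanonicalizerBase *c) {
    return c->getKind() == SpecialKind::Transpose;
  }
};

int main() {
  std::unique_ptr<CanonicalizerBase> c = std::make_unique<TransposeSketch>();
  // What llvm::isa<TransposeSketch>(c.get()) checks under the hood:
  assert(TransposeSketch::classof(c.get()));
}
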
class ForLoopGenerator : virtual public CanonicalizerCommonUsedData { private: + /// currrent function IR func::FuncOp func; public: @@ -579,7 +619,9 @@ class ForLoopGenerator : virtual public CanonicalizerCommonUsedData { virtual ~ForLoopGenerator() noexcept {} void setGeneratorFunc(func::FuncOp &func) noexcept { this->func = func; } + /// clear current group operation void clearCurrentOperationGroup(size_t grpIdx); + /// vectorize operations in current operation group void generateGroupOpVectorizedIR(const int idx); /// prepare for loop iteration args @@ -597,17 +639,18 @@ class ForLoopGenerator : virtual public CanonicalizerCommonUsedData { DenseMap> &indiceLoopMap, const size_t groupId, Operation *op, const DenseMap &setIdxMap = DenseMap({})); + /// rewrite operation as vectorize IR in current operation group void rewriteOperationAsVectorize(OpBuilder &rewriter, size_t groupId, const std::queue *queue = nullptr); + /// Reimplementation of writing a tensor from a constant of denseElementattr. void createNewConstantOp(Operation *srcOp, vector::TransferWriteOp *transferWriteOp, size_t groupSteps); - // elementwise for loop + // Generate elementwise operation for loop mlir::FailureOr generateVectorizedForLoop(const size_t groupId, IRRewriter &rewriter, const VectorType vectorType); - scf::ForOp constructNestedForOp(const size_t groupIdx, OpBuilder &b, const Location &loc, ArrayRef dims, GenerateLoopHelper &loopHelper); @@ -616,41 +659,42 @@ class ForLoopGenerator : virtual public CanonicalizerCommonUsedData { std::queue &queue, GenerateLoopHelper &loopHelperParam); + /// Set correct operand with loop args for the operation void setOperationCorrectOperand( - Operation *op, ValueRange iterArgs, - const DenseMap &operandIdxMap, - DenseMap &originalOperandLoopArgsMap, - ArrayRef inductionVars, - const DenseMap &opPermuationMap, - DenseMap> &indiceLoopMap); + Operation *op, const DenseMap &opPermuationMap, + GenerateLoopHelper &loopHelperParam); + /// Get current anchor return retults void getResultInCurrentOps(const size_t anchorIdx, const size_t groupId, const std::queue &ops, SmallVector &results, DenseMap &nextAnchorResultsIdxMap, DenseMap &forResultOrignalResultMap); - /// get next anchor's iteration loop args + /// Get next anchor's iteration loop args void getInitArgsToNextAnchor(llvm::DenseMap &nextAnchorArgsIdxMap, llvm::SmallVector &nextAnchorArgs, GenerateLoopHelper &loopHelperParam); - + /// Get operation should appear in current loop anchor void getOperationInCurrentAnchor(const size_t anchorIdx, std::queue &fromQueue, std::queue &toQueue); - /// get current loop operation result + /// Get current loop operation result void generateLoopResults(OpBuilder &b, const Location &loc, GenerateLoopHelper &loopHelperParam, DenseMap &nextOperandIdxMap); - /// todo: need to add a struct to remove so many parameters + /// Move post operations in current operation group to the for loop body void movePostOpToCurrentAnchor(OpBuilder &b, GenerateLoopHelper &loopHelperParam); + /// Move previous operations in current operation group to the for loop body void movePreOpToCurrentAnchor(OpBuilder &b, DenseMap &nextLoopStateIdxMap, SmallVector &nextAnchorArgs, GenerateLoopHelper &loopHelperParam); + /// replace moved operation result used by current post operations with for + /// loop result void replaceOperationsWithForLoopResult( IRRewriter &rewrite, const std::queue &movingOperations, GenerateLoopHelper &loopHelperParam); @@ -662,42 +706,41 @@ class ForLoopGenerator : virtual public 
CanonicalizerCommonUsedData { const size_t grpIdx, DenseMap> &indiceLoopMap); + /// reduction operation reduction axis for loop scf::ForOp reductionAxisGenerateForLoop(OpBuilder &opBuilder, const size_t reductionIdx, GenerateLoopHelper &loopHelperParam); - + /// reduction operation parallel axis for loop scf::ForOp parallelAxisGenerateForLoop(OpBuilder &opBuilder, GenerateLoopHelper &loopHelperParam); - + /// ensure accumulate operation appear in parallel loop, inorder to have + /// correct reduce fusion void ensureAccInParallelLoop(GenerateLoopHelper &loopHelperParam, ArrayRef parallelAxis, Value multiReductionAcc, DenseMap &nextAnchorArgsIdxMap, SmallVector &nextAnchorArgs); - vector::TransferReadOp cloneReductionTransferRead( - Value &source, OpBuilder &b, IRMapping &readMap, - const llvm::SmallVector ¶llelAxis, - SmallVector &inductionVars, bool lastDimReduction, - MultiReduceOpAxisKind rdKind = MultiReduceOpAxisKind::Parallel); - /// generate for loop for transpose operation scf::ForOp generateTransposeForLoop(const size_t grpIdx); + /// shuffle instruction optimize for transpose operation scf::ForOp generateTransposeForLoopWithLastDim( OpBuilder &opBuilder, const int tpSteps, const Location &loc, Operation *successorWriteOp, GenerateLoopHelper &loopHelperParam); + /// generate transpose operation for loop of simple data movement scf::ForOp generateTransposeScalarDataMovement(OpBuilder &opBuilder, const Location &loc, DenseMap &tpAxisMap, GenerateLoopHelper &loopHelperParam); - // shapecast + /// generate shapecast operation for loop scf::ForOp generateShapeCastForLoop(const size_t grpIdx); + /// generate simple data movement for loop scf::ForOp generateShapeCastReadWriteLoop( OpBuilder &opBuilder, const size_t grpIdx, const size_t forDimIdx, const size_t steps, const Location &loc, - SmallVector &inductionVars, const ValueRange &iterArgs); + SmallVector &inductionVars, ValueRange iterArgs); /// rectify indice for transfer_write operation /// e.g.: vector.transfer_write"(%16, %9, %c0, %c0), the first %c0 should use /// original indice not create by us @@ -756,11 +799,13 @@ class VectorOperationAnalyzer : virtual public CanonicalizerCommonUsedData { void replaceConstantOpAsNewOp(Operation *op, Operation *sourceOp, size_t operandIdx); }; -/// Vectorize vector operation with target machines simd instructions. +/// Vectorize vector operation with target machines max simd length. 
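For a reduction group, parallelAxisGenerateForLoop and reductionAxisGenerateForLoop split the work into an outer loop over the parallel axes that carries the destination and an inner loop over the reduction axes that carries the accumulator. The scalar C++ analogue below illustrates the shape of the loop nest they emit for a 2-D, last-dim reduction; the real output is nested scf.for over vector values, not scalars.

#include <cstdio>
#include <vector>

int main() {
  const int parallel = 4, reduction = 8;
  std::vector<float> src(parallel * reduction, 1.0f), dst(parallel, 0.0f);

  for (int p = 0; p < parallel; ++p) {    // parallel-axis loop
    float acc = 0.0f;                     // iter_arg of the inner loop
    for (int r = 0; r < reduction; ++r)   // reduction-axis loop
      acc += src[p * reduction + r];
    dst[p] = acc;                         // yielded back to the outer loop
  }
  std::printf("%f\n", dst[0]);            // prints 8.000000
}
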
class CanonicalizerVectorOperation : virtual public ForLoopGenerator, VectorOperationAnalyzer { private: + /// current function IR func::FuncOp func; + /// rewriter of func operation IRRewriter rewriter; CanonicalizerKind kind; @@ -785,15 +830,19 @@ class CanonicalizerVectorOperation : virtual public ForLoopGenerator, // get functions func::FuncOp &getFunc() noexcept { return func; }; IRRewriter &getIRWewriter() noexcept { return rewriter; } + /// generate for loop for current special operation use \param generateFunc template void processSpecialOperation( T &canonicalizers, const std::function &generateFunc); - // + // Canonicalize special operation void canonicalizeSpecialOperation(); + /// clear special operation canonicalizer container void clearSpecialOperationCanonicalizers(); + /// add a dummy special canonicalizer void dummyInitSpecialOperation(size_t steps); + /// initialize all the speical operation canonicalizer void initSpeicalOperationCanonicalizers(); - + /// run the vector canonicalizer for the IR void run(); }; } // namespace diff --git a/test/mlir/test/gc/Transforms/cpu-vetor-distribution.mlir b/test/mlir/test/gc/Transforms/cpu-vetor-distribution.mlir index 081e13af8..ccf2c9676 100644 --- a/test/mlir/test/gc/Transforms/cpu-vetor-distribution.mlir +++ b/test/mlir/test/gc/Transforms/cpu-vetor-distribution.mlir @@ -53,10 +53,12 @@ func.func @add_tensor_test0(%arg0: tensor<11008x4096xf32>, %arg1: tensor<11008x4 // CHECK: scf.yield // CHECK: %[[WRITE:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<16xf32>, tensor<16x64xf32> // CHECK: scf.yield -// CHECK: scf.for -// CHECK: scf.for -// CHECK: scf.for -// CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<16x64xf32>, vector<16xf32> +// CHECK: scf.for %[[arg1:.*]] = %[[C0]] to %[[C16]] step %[[C1]] iter_args(%[[arg2:.*]] = %[[EMPTY1]]) -> (tensor<16x1x64xf32>) +// CHECK: scf.for %[[arg3:.*]] = %[[C0]] to %[[C1]] step %[[C1]] iter_args(%[[arg4:.*]] = %[[arg2]]) -> (tensor<16x1x64xf32>) +// CHECK: scf.for %[[arg5:.*]] = %[[C0]] to %[[C64]] step %[[C16]] iter_args(%[[arg6:.*]] = %[[arg4]]) -> (tensor<16x1x64xf32>) +// CHECK: %[[APPLY0:.*]] = affine.apply #[[map0]]()[%[[arg3]], %[[arg5]]] +// CHECK: %[[READ2:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<16x64xf32>, vector<16xf32> +// CHECK: %[[WRITE1:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<16xf32>, tensor<16x1x64xf32> func.func @reduce_keepdimtest1(%arg0: tensor<16x32x64xf32>) -> tensor<16x1x64xf32> { %0 = tensor.empty() : tensor<16x64xf32> %reduce = linalg.reduce From ffc5569f43479d7136fbbe02938c5eb1bf18e26f Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Tue, 10 Sep 2024 09:11:15 +0800 Subject: [PATCH 48/66] fix clang-tidy --- lib/gc/Transforms/CPUPhysicalRegisterPass.cpp | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp index 46e0c8e10..6a0410896 100644 --- a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp +++ b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp @@ -3021,15 +3021,10 @@ bool VectorFusionStrategy::isNeedNewGroup(Operation *op) { Operation *prevOp = nullptr; prevOp = getNotReadWriteOperaiton(tmpQ); if (!prevOp) { - - if (opGroups.back().back()->getParentOp() != op->getParentOp() or - isSpecialOp(op)) { - // if previous operation is not in the same block, we need to create a - // group - return true; - } - - return false; + // 
if previous operation is not in the same block, we need to create a + // group + return opGroups.back().back()->getParentOp() != op->getParentOp() or + isSpecialOp(op); } if (prevOp->getParentOp() != op->getParentOp()) From 32f20ddb9564d476aa5ee595cc612f3868d9f2d1 Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Tue, 10 Sep 2024 10:11:15 +0800 Subject: [PATCH 49/66] remove unused function --- lib/gc/Transforms/CPUPhysicalRegisterPass.cpp | 35 +++---------------- 1 file changed, 4 insertions(+), 31 deletions(-) diff --git a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp index 6a0410896..5459bf1ec 100644 --- a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp +++ b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp @@ -733,19 +733,6 @@ scf::YieldOp maybeYieldValue(OpBuilder &b, Location loc, return b.create(loc); } -Type getScalarType(Operation *op) { - // Check that the operation type can be broken - // down into a loop. - auto baseType = getOperationVectorType(op); - if (failed(baseType)) { - LDBG("Failed to get vector type for operation: " << *op << "\n"); - assert(false && "Failed to get vector type for operation"); - return VectorType(); - } - auto vectorizedType = baseType.value(); - return VectorType::get({1}, vectorizedType.getElementType()); -} - Operation *createTensorEmptyBefore(Operation *op) { auto rtType = cast(op->getResultTypes()[0]); @@ -966,19 +953,6 @@ void classifyAccRelatedOps(std::queue &accRelatedOps, } } -void updateReduceReadWriteOperationOperand( - const SmallVector &inductionVars, - const SmallVector ¶llelAxis, Operation *op, - MultiReduceOpAxisKind rdKind = MultiReduceOpAxisKind::Parallel) { - int indiceOffset = isa(op) ? 1 : 2; - for (auto [idx, inductionVar] : llvm::enumerate(inductionVars)) { - if (rdKind == MultiReduceOpAxisKind::Parallel && idx >= parallelAxis.size()) - break; - - op->setOperand(idx + indiceOffset, inductionVar); - } -} - Value makeIndexArithConstantOp(OpBuilder &opBuilder, const Location &loc, int64_t x) { return opBuilder.create( @@ -1359,7 +1333,7 @@ scf::ForOp ForLoopGenerator::reductionAxisGenerateForLoop( rewriteOperationAsVectorize(b, loopHelperParam.groupIdx, &movingOperation); - + loopHelperParam.loopIterArgs = loopState; moveOperationsToCurrentForBody(b, movingOperation, loopHelperParam); loopHelperParam.movedOps = &movingOperation; loopHelperParam.candidateOps = &opQueue; @@ -2104,15 +2078,13 @@ scf::ForOp ForLoopGenerator::generateShapeCastForLoop(const size_t grpIdx) { if (canVectorizedLoadStore) { forOp = generateShapeCastReadWriteLoop( b, grpIdx, 0, groupStep, scOp.getLoc(), inductionVars, iterArgs); - for (auto [idx, successorWriteOp] : enumerate(successorWriteOps)) - rewriter.replaceOp(successorWriteOp, forOp->getResults()[idx]); } else { // scalar data movement forOp = generateShapeCastReadWriteLoop(b, grpIdx, 0, 1, scOp.getLoc(), inductionVars, iterArgs); - for (auto [idx, successorWriteOp] : enumerate(successorWriteOps)) - rewriter.replaceOp(successorWriteOp, forOp->getResults()[idx]); } + for (auto [idx, successorWriteOp] : enumerate(successorWriteOps)) + rewriter.replaceOp(successorWriteOp, forOp->getResults()[idx]); rewriter.eraseOp(scOp); clearCurrentOperationGroup(grpIdx); @@ -2658,6 +2630,7 @@ scf::ForOp ForLoopGenerator::constructNestedForOp( if (loopHelper.anchorIdx == dims.size() - 1) { std::queue &opQueue = getFusionStrategy().getOpGroups()[groupIdx]; + loopHelper.loopIterArgs = loopState; // 1. 
get operations in current anchor position std::queue movingOperation; getOperationInCurrentAnchor(loopHelper.anchorIdx, opQueue, From a4382c200f05454a4d8b39faa16770a9bb9dc040 Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Fri, 13 Sep 2024 13:57:36 +0800 Subject: [PATCH 50/66] split analysis file --- .../gc/Analysis/VectorBasedFusionAnalysis.h | 284 +++++ lib/gc/Analysis/CMakeLists.txt | 1 + lib/gc/Analysis/VectorBasedFusionAnalysis.cpp | 720 +++++++++++ lib/gc/Transforms/CPUPhysicalRegisterPass.cpp | 1064 +++-------------- lib/gc/Transforms/Pipeline.cpp | 4 +- .../{TilingVector.h => TilingVector.hpp} | 547 +++------ .../gc/Transforms/cpu-vetor-distribution.mlir | 46 +- 7 files changed, 1409 insertions(+), 1257 deletions(-) create mode 100644 include/gc/Analysis/VectorBasedFusionAnalysis.h create mode 100644 lib/gc/Analysis/VectorBasedFusionAnalysis.cpp rename lib/gc/Transforms/{TilingVector.h => TilingVector.hpp} (71%) diff --git a/include/gc/Analysis/VectorBasedFusionAnalysis.h b/include/gc/Analysis/VectorBasedFusionAnalysis.h new file mode 100644 index 000000000..8a0d354c6 --- /dev/null +++ b/include/gc/Analysis/VectorBasedFusionAnalysis.h @@ -0,0 +1,284 @@ +//===-- VectorBasedFusionAnalysis.h - vector fusion analysis ----*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_ANALYSIS_VECTORBASEDFUSIONANALYSIS_H +#define MLIR_ANALYSIS_VECTORBASEDFUSIONANALYSIS_H + +#include "gc/Dialect/Linalgx/LinalgxOps.h" +#include "gc/Dialect/Microkernel/MicrokernelOps.h" +#include "gc/Transforms/Passes.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Traits.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/BuiltinTypes.h" +#include "llvm/ADT/TypeSwitch.h" +#include + +namespace mlir { +namespace gc { + +mlir::FailureOr getOperationVectorType(Operation *op, + bool isPrevOp = true); +int getNearestVectorStep(const int step); +mlir::FailureOr getOperationMaxVectorType(Operation *op); + +/// record hardware information +struct HardWareInfo { + bool favx512f = true; + bool favx2 = true; +}; + +/// Vector type conversion helper class +class TypeHelper { +private: + HardWareInfo info; + +public: + TypeHelper() = default; + TypeHelper(HardWareInfo info) : info(info) {} + /// get current hardware information + HardWareInfo &getHardwareInfo() { return this->info; } + /// use \param info to set hardware information + void setHardWareInfo(HardWareInfo &info) { this->info = info; } + /// get vector \param type max loop step according to hardware information + int getDataTypeValidSteps(VectorType type); + /// get vector \param type an even for loop step + int generateValidSteps(int steps, VectorType type); + /// get vector \param type max simd length according to hardware information + int getDataTypeMAXSIMDLength(VectorType type); + /// get operation's vector type + VectorType getVectorzedType(Operation *op, uint32_t loopStep = 0); +}; + +/// operation return kind, which is used to determine whether the operation +/// need to return it's result in current for loop +enum class ReturnTypeKind { + RT_Both, + RT_OutGroup, + RT_InGroup, +}; + +class VectorFusionBase { + +private: + /// current 
function IR + func::FuncOp func; + /// Type helper class, can help us to get operation type + TypeHelper typehelper; + +public: + VectorFusionBase() = default; + VectorFusionBase(func::FuncOp &func, HardWareInfo &info) + : func(func), typehelper(info) {} + VectorFusionBase(VectorFusionBase &base) + : func(base.getFunction()), typehelper(base.getHardwareInfo()) {} + + /// get current function IR + func::FuncOp &getFunction() { return func; } + /// get current hardware info + HardWareInfo &getHardwareInfo() { return typehelper.getHardwareInfo(); } + TypeHelper &getTypeHelper() { return typehelper; } +}; + +/// Group operation fusion strategy class. +/// 1. Classify operaions: +/// classify the operations into : +/// a. reorder, transpose. Reorder(or transpose) dim may bring data +/// dependency. +/// b. elemenwise. Those operations can be fused into a common for loop. +/// c. broadcast. Need to analysis broadcast dim and the data +/// dependency. +/// d. reduction. Need to analysis broadcast dim and the +/// data dependency. +/// Same group operations have no data dependencies. They can be fused into a +/// common for loop body. + +/// Using queue to store the operation order. In order to ensure that +/// subsequent moves to the operation will not cause semantic changes. +class GroupOperationFusion : public VectorFusionBase { +private: + /// operation groups, operations in each group can generate a common for + /// loop + SmallVector, 8> opGroups; + /// group max vectorize steps + SmallVector groupMaxSteps; + /// vector type which has bigest rank in current operation group + llvm::SmallDenseMap groupBigestRankVectorType; + /// query current operation in which group, return group index + DenseMap opGroupIndexMap; + /// can fused into prev operation which axis position + DenseMap opAnchorPos; + /// record some operations which not need to No need to judge whether can be + /// fused + std::queue notNeedToJudgeOps; + /// analysis the operation's operands and results + SmallVector>, 8> + groupOpResults; + /// store loop iteration args for each of operation group + SmallVector, 8> groupOpInitArgs; + // store read and write operations permutation maps in order to convenient + // to replace loop induction var + DenseMap opPermuationMap; + +public: + GroupOperationFusion(func::FuncOp &func, HardWareInfo &info) + : VectorFusionBase(func, info) {} + + GroupOperationFusion(GroupOperationFusion &strategy) + : VectorFusionBase(strategy.getFunction(), strategy.getHardwareInfo()), + opGroups(strategy.opGroups), groupMaxSteps(strategy.groupMaxSteps), + opGroupIndexMap(strategy.opGroupIndexMap), + opAnchorPos(strategy.opAnchorPos) {}; + + GroupOperationFusion(GroupOperationFusion &&strategy) + : VectorFusionBase(strategy.getFunction(), strategy.getHardwareInfo()), + opGroups(std::move(strategy.opGroups)), + groupMaxSteps(std::move(strategy.groupMaxSteps)), + groupBigestRankVectorType( + std::move(strategy.getGroupBiggestRankVectorType())), + opGroupIndexMap(std::move(strategy.opGroupIndexMap)), + opAnchorPos(std::move(strategy.opAnchorPos)) {}; + + GroupOperationFusion &operator=(GroupOperationFusion &fusion) { + this->getOpGroups() = fusion.getOpGroups(); + this->getGroupMaxSteps() = fusion.getGroupMaxSteps(); + this->getGroupBiggestRankVectorType() = + fusion.getGroupBiggestRankVectorType(); + this->getOpGroupIndexMap() = fusion.getOpGroupIndexMap(); + this->getOpAnchorPos() = fusion.getOpAnchorPos(); + this->notNeedToJudgeOps = fusion.notNeedToJudgeOps; + this->getGroupOpResults() = 
fusion.getGroupOpResults(); + this->getGroupOpInitArgs() = fusion.getGroupOpInitArgs(); + this->getOpPermuationMap() = fusion.getOpPermuationMap(); + this->getFunction() = fusion.getFunction(); + this->getHardwareInfo() = fusion.getHardwareInfo(); + this->getTypeHelper() = fusion.getTypeHelper(); + return *this; + }; + GroupOperationFusion &operator=(GroupOperationFusion &&) = default; + + /// Get the map which contains each group vector type which has biggest + /// rank. + llvm::SmallDenseMap & + getGroupBiggestRankVectorType() noexcept { + return groupBigestRankVectorType; + }; + /// Get the operation group obtained by fusion strategy analysis + SmallVector, 8> &getOpGroups() noexcept { + return opGroups; + } + /// Get the operation belong to which group index map + DenseMap &getOpGroupIndexMap() noexcept { + return opGroupIndexMap; + } + /// Get the map contains max steps of each group + SmallVector &getGroupMaxSteps() noexcept { + return groupMaxSteps; + } + /// Get the map contains anchor position of each operation + DenseMap &getOpAnchorPos() noexcept { + return opAnchorPos; + } + /// get current operation group results + SmallVector>, 8> & + getGroupOpResults() noexcept { + return groupOpResults; + } + + SmallVector, 8> &getGroupOpInitArgs() noexcept { + return groupOpInitArgs; + } + + DenseMap &getOpPermuationMap() noexcept { + return opPermuationMap; + } + /// set operation groups + void setGroupOpResults( + const SmallVector< + llvm::MapVector>, 8> + &results) { + groupOpResults = std::move(results); + } + + void setGroupOpIterArgs( + const SmallVector, 8> &initArgs) noexcept { + groupOpInitArgs = std::move(initArgs); + } + + void setPermutationMap(const DenseMap &map) noexcept { + opPermuationMap = std::move(map); + } + /// Do fusion strategy + void classifyOperations(); + + /// Whether two operations have compatible vector shapes + bool isCompatibleVectorType(Operation *op1, Operation *op2); + + /// update bigest vector type for last operation group + void updateGroupBigestVectorType(VectorType vectorType); + + /// Check whether the operation can fuse with previous operation + bool isNeedNewGroup(Operation *op); + + /// Add Operation \p op into current last group or a new Group + /// \p op must has valid value, can't be nullptr + void addOperationToGroup(Operation *op); + + /// get next operation in current operation group + template + Operation *getNextTargetOperationInCurrentGroup(Operation *curOp, + const size_t grpIdx); + + /// run the vector-based fusion strategy + void run(); +}; + +template +Operation *GroupOperationFusion::getNextTargetOperationInCurrentGroup( + Operation *curOp, const size_t grpIdx) { + std::queue tmpOpQueue(getOpGroups()[grpIdx]); + if (isa(curOp)) + return curOp; + + while (!tmpOpQueue.empty()) { + auto frontOp = tmpOpQueue.front(); + if (isa(frontOp)) { + for (auto x : frontOp->getOperands()) + if (x.getDefiningOp() == curOp) + return frontOp; + } + tmpOpQueue.pop(); + } + return nullptr; +} + +class GroupOperationAnalysis { +private: + /// vector-based fusion related data + GroupOperationFusion fusionStrategy; + +public: + GroupOperationAnalysis(func::FuncOp &func, HardWareInfo &info) + : fusionStrategy(func, info) {} + /// remove the useless operation, due to it result is not require by other + /// operation + void analysisEmptyGroup(); + /// get each operation in each group maximum support vectorization length + void analysisGroupMaxSteps(); + /// get fusion strategy + GroupOperationFusion &getGroupOperationFusion() { return fusionStrategy; } + + 
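GroupOperationAnalysis is the entry point a pass is expected to drive. The sketch below only follows the interface declared in this header; the surrounding FuncOp and pass plumbing are assumed and the actual driver lives in CPUPhysicalRegisterPass.cpp.

#include "gc/Analysis/VectorBasedFusionAnalysis.h"

static void runFusionAnalysisSketch(mlir::func::FuncOp func) {
  mlir::gc::HardWareInfo info;             // defaults enable avx512f / avx2
  mlir::gc::GroupOperationAnalysis analysis(func, info);
  analysis.run();                          // classify ops into fusible groups
  analysis.analysisGroupMaxSteps();        // per-group max vectorization step
  mlir::gc::GroupOperationFusion &fusion = analysis.getGroupOperationFusion();
  auto numGroups = fusion.getOpGroups().size();
  (void)numGroups;                         // groups drive later loop generation
}
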
void run() { fusionStrategy.run(); } +}; +} // namespace gc +} // namespace mlir + +#endif \ No newline at end of file diff --git a/lib/gc/Analysis/CMakeLists.txt b/lib/gc/Analysis/CMakeLists.txt index d7160f350..b11eb3bd4 100644 --- a/lib/gc/Analysis/CMakeLists.txt +++ b/lib/gc/Analysis/CMakeLists.txt @@ -5,6 +5,7 @@ gc_set_mlir_link_components(MLIR_LINK_COMPONENTS gc_add_mlir_library(GcAnalysis TargetDescriptionAnalysis.cpp MatmulConfigAnalysis.cpp + VectorBasedFusionAnalysis.cpp DEPENDS GraphCompilerPassIncGen diff --git a/lib/gc/Analysis/VectorBasedFusionAnalysis.cpp b/lib/gc/Analysis/VectorBasedFusionAnalysis.cpp new file mode 100644 index 000000000..2760ada39 --- /dev/null +++ b/lib/gc/Analysis/VectorBasedFusionAnalysis.cpp @@ -0,0 +1,720 @@ +//===- VectorBasedFusionAnalysis.cpp - analysis vector ops ------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#include "gc/Analysis/VectorBasedFusionAnalysis.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" + +namespace mlir { +namespace gc { + +#define DEBUG_TYPE "vector-operation-analysis" +#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") +#define SAFE_EXPAND(X) X +#define LDBG(X) LLVM_DEBUG(DBGS() << SAFE_EXPAND(X) << "\n") + +#define ARITH_CAST_OPERATIONS \ + arith::ExtFOp, arith::ExtSIOp, arith::ExtUIOp, arith::BitcastOp, \ + arith::FPToSIOp, arith::FPToUIOp, arith::SIToFPOp, arith::UIToFPOp, \ + arith::TruncFOp, arith::TruncIOp + +#define NOT_NEED_TO_PROCESS_OP \ + linalg::GenericOp, linalg::BatchReduceMatmulOp, linalg::MatmulOp, \ + linalg::BatchMatmulOp, linalg::BatchMatmulTransposeAOp, \ + linalg::BatchMatmulTransposeBOp, linalg::MatmulTransposeAOp, \ + linalg::MatmulTransposeBOp, linalg::QuantizedBatchMatmulOp, \ + linalg::QuantizedMatmulOp, tensor::CollapseShapeOp, \ + tensor::ExpandShapeOp, tensor::ExtractSliceOp, tensor::InsertSliceOp, \ + microkernel::BrgemmOp + +static inline bool isNotNeedToProcessOp(Operation *op) { + return isa(op); +} + +static inline bool isSpecialOp(Operation *op) { + return isa( + op); +} + +static inline bool isReadOrWriteOperation(Operation *op) { + return isa(op); +} + +/// which axis do the shape cast in source shape a +void shapeCastSourceAxis(const ArrayRef &a, const ArrayRef &b, + SmallVector &res) { + unsigned rankA = a.size(); + unsigned rankB = b.size(); + if (rankA >= rankB) + llvm::llvm_unreachable_internal("May be invalid shape cast operation."); + + auto isOne = [](int64_t v) { return v == 1; }; + + // Special-case for n-D to 0-d shape cast. 'b' must be all ones to be shape + // casted to a 0-d vector. + if (rankA == 0 && all_of(b, isOne)) { + for (size_t i = 0; i < a.size(); i++) { + res.emplace_back(i); + } + return; + } + + unsigned i = 0; + unsigned j = 0; + while (i < rankA && j < rankB) { + int64_t dimA = a[i]; + int64_t dimB = 1; + int64_t bAxisBegin = j; + while (dimB < dimA && j < rankB) + dimB *= b[j++]; + if (dimA != dimB) { + assert(false && " Invalid shape cast operation."); + break; + } + if (bAxisBegin != j) { + res.emplace_back(i); + } + ++i; + + // Handle the case when trailing dimensions are of size 1. + // Include them into the contiguous sequence. 
+ if (i < rankA && all_of(a.slice(i), isOne)) + i = rankA; + if (j < rankB && all_of(b.slice(j), isOne)) + j = rankB; + } + + assert(i == rankA && j == rankB && "Invalid shapecast operation."); +} + +bool isScalar(Type type) { + assert(type && "Not a valid type"); + if (auto vecType = dyn_cast(type)) + return false; + if (auto tensorType = dyn_cast(type)) + return false; + return true; +} + +void getSrcBroadcastDim(const ShapedType &input, const ShapedType &output, + SmallVector &bcAxis) { + auto inputShape = input.getShape(); + auto outputShape = output.getShape(); + // following auto_broadcast semantics + const size_t input_rank = inputShape.size(); + const size_t output_rank = outputShape.size(); + assert(output_rank >= input_rank && + "Incorrect input or output shape for broadcast op."); + const size_t offset = output_rank - input_rank; + for (size_t i = 0; i < input_rank; ++i) { + if (inputShape[i] == outputShape[i + offset] || + (ShapedType::isDynamic(inputShape[i]) && + ShapedType::isDynamic(outputShape[i + offset]))) { + bcAxis.emplace_back(i); + } + } + if (bcAxis.empty()) + bcAxis.emplace_back(-1); +} + +void getOperationDataAxis(Operation *op, SmallVector &dataAxis) { + return TypeSwitch(op) + .Case( + [&](vector::MultiDimReductionOp multiReductionOp) { + auto rdDimsRange = multiReductionOp.getReductionDims(); + dataAxis.assign(rdDimsRange.begin(), rdDimsRange.end()); + return; + }) + .Case([&](vector::ShapeCastOp shapeCastOp) { + auto srcType = shapeCastOp.getSourceVectorType(); + auto dstType = shapeCastOp.getResultVectorType(); + auto srcShape = srcType.getShape(); + auto dstShape = dstType.getShape(); + if (srcShape.size() < dstShape.size()) { + shapeCastSourceAxis(srcShape, dstShape, dataAxis); + } else { + shapeCastSourceAxis(dstShape, srcShape, dataAxis); + } + return; + }) + .Case([&](vector::BroadcastOp broadcastOp) { + auto srcType = broadcastOp.getSourceType(); + auto dstType = broadcastOp.getResultVectorType(); + if (isScalar(srcType)) { + dataAxis.emplace_back(0); + } else { + auto inputType = mlir::cast(srcType); + auto outputType = mlir::cast(dstType); + getSrcBroadcastDim(inputType, outputType, dataAxis); + } + return; + }) + .Case([&](vector::TransposeOp transposeOp) { + auto perm = transposeOp.getPermutation(); + int start = 0; + for (auto x : perm) { + if (x != start) { + dataAxis.emplace_back(x); + } + start++; + } + return; + }) + .Default([&](Operation *op) { + // default is last axis + dataAxis.emplace_back( + cast(op->getResultTypes()[0]).getRank() - 1); + return; + }); +} + +static inline bool hasSameAxis(ArrayRef dims1, + ArrayRef dims2) { + DenseSet checkSet(dims2.begin(), dims2.end()); + return llvm::any_of(dims1, + [&checkSet](int64_t x) { return checkSet.contains(x); }); +} + +/// whether op2 use op1 result +/// Currently we just enable this function for write and read operation +template || + std::is_same_v, + T>> +static bool isOperationsHasDefUseRelation(Operation *op1, Operation *op2) { + return llvm::any_of(op2->getOperands(), + [&op1](Value opd) { return opd.getDefiningOp() == op1; }); +} + +/// whether two operation has data dependency +/// op1 default is previous operation, op2 default is current operation +bool hasDataDependency(Operation *op1, Operation *op2) { + if (!isSpecialOp(op1) and !isSpecialOp(op2)) + return false; + + if (isReadOrWriteOperation(op1) or isReadOrWriteOperation(op2)) { + // if op1 is read the value and pass it to op2, it is not data dependency + if (isOperationsHasDefUseRelation(op1, op2)) + return false; + } + + // 
broadcast only fuse with post-op + if (isa(op2)) + return true; + + if (isa(op1)) + return true; + + // only special operation may cause data dependency + if (!isSpecialOp(op1)) + return hasDataDependency(op2, op1); + + auto res = + TypeSwitch(op1) + .Case([&](vector::ShapeCastOp shapeCastOp) { + SmallVector dims1, dims2; + getOperationDataAxis(op1, dims1); + getOperationDataAxis(op2, dims2); + if (!isSpecialOp(op2)) + return hasSameAxis(dims1, dims2); + + return true; + }) + .Case( + [&](vector::MultiDimReductionOp multiReductionOp) { + SmallVector dims2, reductionDims, parallelDims; + getOperationDataAxis(op1, reductionDims); + getOperationDataAxis(op2, dims2); + DenseSet checkSet(dims2.begin(), dims2.end()); + auto op2VectorType = getOperationVectorType(op2); + if (!isSpecialOp(op2)) { + // all reduction axis should be op2's data axis + bool reduceDependent = false; + for (auto x : reductionDims) { + if (!checkSet.contains(x)) { + reduceDependent = true; + break; + } + } + if (!reduceDependent) + return false; + + // all parallel axis should equal to op2's axis + checkSet.clear(); + checkSet.insert(reductionDims.begin(), reductionDims.end()); + auto rdRank = + multiReductionOp.getSourceVectorType().getRank(); + for (auto i = 0; i < rdRank; i++) + if (not checkSet.contains(i)) + parallelDims.emplace_back(i); + + checkSet.clear(); + checkSet.insert(parallelDims.begin(), parallelDims.end()); + auto rank = op2VectorType->getRank(); + for (auto i = 0; i < rank; i++) + if (!checkSet.contains(i)) + return true; + + return false; + } + + return true; + }) + .Case([&](vector::BroadcastOp broadcastOp) { + if (isSpecialOp(op2)) + return true; + + return !OpTrait::util::staticallyKnownBroadcastable( + getOperationVectorType(op1, false)->getShape(), + getOperationVectorType(op2)->getShape()); + }) + .Case( + [&](vector::TransposeOp transposeOp) { return true; }) + .Default([&](Operation *op) { return false; }); + + return res; +} + +/// Get vector type of the operation \param op +/// \param isPrevOp whether the operation is a previous operation, if it is not +/// prev-op, may need to use result vectortype +/// default will return the opeation result type +mlir::FailureOr getOperationVectorType(Operation *op, + bool isPrevOp) { + if (not op) + return failure(); + + auto isDynamicType = [](VectorType &type) { return !type.hasStaticShape(); }; + auto ret = + TypeSwitch>(op) + .Case( + [&](vector::TransferWriteOp transferWriteOp) + -> mlir::FailureOr { + if (auto retType = dyn_cast( + transferWriteOp.getOperandTypes()[0])) + return retType; + + return failure(); + }) + .Case( + [&](vector::TransferReadOp transferReadOp) + -> mlir::FailureOr { + return transferReadOp.getVectorType(); + }) + .Case( + [&](vector::MultiDimReductionOp multiReductionOp) { + if (isPrevOp) + return cast( + multiReductionOp->getResultTypes()[0]); + + // TODO: may need to add accumulate value vectortype + return cast(multiReductionOp.getSourceVectorType()); + }) + .Default([&](Operation *op) -> mlir::FailureOr { + if (isPrevOp) { + if (op->getResultTypes().empty()) + return failure(); + + if (auto shapedType = + dyn_cast(op->getResultTypes()[0])) + return shapedType; + + return failure(); + } + if (op->getOperandTypes().empty()) + return failure(); + + if (auto shapedType = + dyn_cast(op->getOperandTypes()[0])) + return shapedType; + + return failure(); + }); + if (!failed(ret) and isDynamicType(ret.value())) + return failure(); + + return ret; +} + +/// get operation vector type +/// \param isPrevOp whether the operation is a 
previous operation, if it is not +/// prev-op, may need to use result vectortype +/// default will return the opeation result type +mlir::FailureOr getOperationMaxVectorType(Operation *op) { + if (not op) + return failure(); + + auto isDynamicType = [](VectorType &type) { return !type.hasStaticShape(); }; + auto ret = + TypeSwitch>(op) + .Case( + [&](vector::TransferWriteOp transferWriteOp) + -> mlir::FailureOr { + if (auto retType = + cast(transferWriteOp.getOperandTypes()[0])) + return retType; + return failure(); + }) + .Case( + [&](vector::TransferReadOp transferReadOp) + -> mlir::FailureOr { + return transferReadOp.getVectorType(); + }) + .Case( + [&](vector::MultiDimReductionOp multiReductionOp) { + return cast(multiReductionOp.getSourceVectorType()); + }) + .Default([&](Operation *op) -> mlir::FailureOr { + if (op->getResultTypes().empty() and op->getOperandTypes().empty()) + return failure(); + + if (op->getResultTypes().empty()) + return cast(op->getOperandTypes()[0]); + + if (op->getOperandTypes().empty()) + return cast(op->getResultTypes()[0]); + + auto opdType = cast(op->getOperandTypes()[0]); + auto retType = cast(op->getResultTypes()[0]); + return opdType.getRank() > retType.getRank() ? opdType : retType; + }); + if (!failed(ret) and isDynamicType(ret.value())) + return failure(); + + return ret; +} + +/// select nearest even step +int getNearestVectorStep(const int step) { + assert(step > 0); + int nbits = 0, n = step; + while (n) { + n = n >> 1; + nbits++; + } + assert(nbits <= 6 || (nbits == 7 && step == 64)); + return (1 << (nbits - 1)) == step ? step : (1 << nbits); +} + +/// Get the operation which is not a read-write in current queue +/// \param [in, out] op +Operation *getNotReadWriteOperaiton(std::queue &tmpQ) { + Operation *op = nullptr; + while (!tmpQ.empty()) { + Operation *cur = tmpQ.front(); + tmpQ.pop(); + if (isReadOrWriteOperation(cur)) + continue; + + op = cur; + } + return op; +} + +/// operation should not contain for loop +bool is_innermost_operation(Operation *op) { + bool inner_most = true; + op->walk([&inner_most](Operation *p) { + if (isa(p)) { + inner_most = false; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return inner_most; +} + +/// whether operate on last dimension +bool isLastDim(const AffineExpr &expr, const size_t rank) { + return isa(expr) && + dyn_cast(expr).getPosition() == rank - 1; +} + +bool isReadWriteOnLastDim(Operation *op) { + if (isReadOrWriteOperation(op)) { + AffineMap permutationMap = + dyn_cast(op) + ? cast(op).getPermutationMap() + : cast(op).getPermutationMap(); + int64_t rank = + dyn_cast(op) + ? cast(op->getOperand(0).getType()).getRank() + : cast(op->getOperand(1).getType()).getRank(); + ArrayRef dimExpr = permutationMap.getResults(); + bool find = false; + for (const auto &expr : dimExpr) + if (isLastDim(expr, rank)) { + find = true; + break; + } + + return find; + } + llvm::llvm_unreachable_internal( + "The operation is not a read or write operation."); + return false; +} + +// Filter out the operations that can be vectorized. We are only interested in +// operations that do not contain any for loops(innermost IR). 
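getNearestVectorStep above rounds a positive length up to the next power of two, leaving exact powers of two unchanged, and the hardware helpers later in this file divide the register width by the element bit width to get the lane count. The standalone restatement below shows the same bit trick with worked values for clarity only; the real helpers are the ones defined in this file.

#include <cassert>

int nearestVectorStep(int step) {
  assert(step > 0);
  int nbits = 0, n = step;
  while (n) { n >>= 1; ++nbits; }
  return (1 << (nbits - 1)) == step ? step : (1 << nbits);
}

int main() {
  assert(nearestVectorStep(6) == 8);    // rounded up to the next power of two
  assert(nearestVectorStep(16) == 16);  // exact power of two passes through
  assert(512 / 32 == 16);               // f32 lanes in a 512-bit AVX-512 register
}
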
+[[nodiscard]] bool filterOperation(Operation *op) { + if (!is_innermost_operation(op)) { + return false; + } + + // We are only interested about the operation in vector dialect + if (failed(getOperationVectorType(op))) { + return false; + } + + // We don't need to vectorize the constant operation + if (isa(op)) { + return false; + } + + if (isReadOrWriteOperation(op) and !isReadWriteOnLastDim(op)) { + return false; + } + + return true; +} + +VectorType TypeHelper::getVectorzedType(Operation *op, uint32_t loopStep) { + // Check that the operation type can be broken + // down into a loop. + mlir::FailureOr baseType = getOperationVectorType(op); + if (failed(baseType)) { + assert(0 && "Failed to get vector type for operation"); + return VectorType(); + } + auto vectorizedType = baseType.value(); + if (loopStep == 0) + loopStep = getDataTypeValidSteps(vectorizedType); + + return VectorType::get({loopStep}, vectorizedType.getElementType()); +} + +int TypeHelper::generateValidSteps(int steps, VectorType type) { + if (type.getShape().back() >= steps) + return steps; + int evenStep = getNearestVectorStep(type.getShape().back()); + auto typebits = type.getElementTypeBitWidth(); + return evenStep * typebits >= 128 ? evenStep : 1; +} + +// Get the maximum number of current data types that a register can hold +[[nodiscard]] int TypeHelper::getDataTypeMAXSIMDLength(VectorType type) { + auto typebits = type.getElementTypeBitWidth(); + const int favx512bits = 512; + const int favx2bits = 256; + if (info.favx512f) + return favx512bits / typebits; + + if (info.favx2) + return favx2bits / typebits; + + // invalid hardware + assert(false && "Invalid hardware."); + return -1; +} + +/// Get a appropriate for loop step for current vector type +[[nodiscard]] int TypeHelper::getDataTypeValidSteps(VectorType type) { + return generateValidSteps(getDataTypeMAXSIMDLength(type), type); +} + +/// default op1 is previous operation +bool GroupOperationFusion::isCompatibleVectorType(Operation *op1, + Operation *op2) { + // only lower to vector pass can produce read operation. 
In general two read + // operation is compatible + if (isa(op1) and isa(op2)) { + return true; + } + + mlir::FailureOr type1 = getOperationVectorType(op1, true); + mlir::FailureOr type2 = getOperationVectorType(op2, false); + // some operation has two different operands type like multireduction, we need + // to check whether compitable with accumulate vector + VectorType suppleType; + if (failed(type1) || failed(type2)) + return false; + + auto sp1 = type1.value(); + auto sp2 = type2.value(); + + auto isCompatible = [](VectorType sp1, VectorType sp2) { + bool isCompatible = true; + auto min_rank = std::min(sp1.getRank(), sp2.getRank()); + // from front to back + for (long i = 0; i < min_rank; i++) { + if (sp1.getDimSize(i) != sp2.getDimSize(i)) { + isCompatible = false; + break; + } + } + return isCompatible; + }; + + bool result; + result = isCompatible(sp1, sp2); + // operand check only happen on later operation is op2 + // TODO: may need to support other similar operation like multireduction has + // two different operands type + if (isa(op2)) { + suppleType = cast(op2->getOperandTypes()[1]); + result |= isCompatible(suppleType, sp1); + } + + return result; +} + +void GroupOperationFusion::updateGroupBigestVectorType(VectorType vectorType) { + int64_t rank = vectorType.getRank(); + llvm::SmallDenseMap &groupVectorType = + getGroupBiggestRankVectorType(); + + if (groupVectorType.contains(opGroups.size() - 1)) { + VectorType bigestType = groupVectorType[opGroups.size() - 1]; + if (bigestType.getRank() < rank) + groupVectorType[opGroups.size() - 1] = vectorType; + + return; + } + + groupVectorType[opGroups.size() - 1] = vectorType; +} + +void GroupOperationFusion::addOperationToGroup(Operation *op) { + assert(op); + VectorType vectorType = getOperationMaxVectorType(op).value(); + if (isNeedNewGroup(op)) + opGroups.emplace_back(std::queue()); + + if (not isa(op)) { + updateGroupBigestVectorType(vectorType); + while (not notNeedToJudgeOps.empty()) { + auto cur = notNeedToJudgeOps.front(); + notNeedToJudgeOps.pop(); + opGroupIndexMap[cur] = opGroups.size() - 1; + opGroups.back().push(cur); + } + opGroups.back().push(op); + opGroupIndexMap[op] = opGroups.size() - 1; + } + opAnchorPos[op] = getOperationMaxVectorType(op)->getRank() - 1; +} + +// We classify the operations we are interested in after filtering. Operations +// of in the same group have no data dependencies. Those operations can generate +// a same outter for loop. +void GroupOperationFusion::classifyOperations() { + // dummpy + if (opGroups.empty()) + opGroups.emplace_back(std::queue()); + + func::FuncOp func = getFunction(); + + func->walk([&](Operation *op) { + if (filterOperation(op)) { + addOperationToGroup(op); + return WalkResult::advance(); + } + if (isNotNeedToProcessOp(op) and !opGroups.back().empty()) + opGroups.emplace_back(std::queue()); + + return WalkResult::advance(); + }); + // init operations results and initialization args + groupOpResults.clear(); + groupOpInitArgs.clear(); + for (size_t i = 0; i < opGroups.size(); i++) { + groupOpResults.emplace_back( + llvm::MapVector>()); + groupOpInitArgs.emplace_back(SetVector()); + } +} + +void GroupOperationFusion::run() { classifyOperations(); } + +bool GroupOperationFusion::isNeedNewGroup(Operation *op) { + if (isa(op)) { + notNeedToJudgeOps.push(op); + return false; + } + // 1. check previous operation + if (!opGroups.back().empty()) { + // We only care about the calculation operation. 
+ std::queue tmpQ(opGroups.back()); + Operation *prevOp = nullptr; + prevOp = getNotReadWriteOperaiton(tmpQ); + if (!prevOp) { + // if previous operation is not in the same block, we need to create a + // group + return opGroups.back().back()->getParentOp() != op->getParentOp() or + isSpecialOp(op); + } + + if (prevOp->getParentOp() != op->getParentOp()) + return true; + + // special operation need to check data dependency axis + if (hasDataDependency(prevOp, op)) + return true; + + // previous operation vector type is not compatible with current operation + if (!isCompatibleVectorType(prevOp, op)) + return true; + } + return false; +} + +void GroupOperationAnalysis::analysisEmptyGroup() { + SmallVector, 8> &opGroups = + fusionStrategy.getOpGroups(); + SmallVector>, 8> + &groupOpResults = fusionStrategy.getGroupOpResults(); + for (auto [idx, grp] : llvm::enumerate(opGroups)) { + if (grp.empty()) + continue; + if (groupOpResults[idx].empty()) + std::queue().swap(grp); + } +} + +void GroupOperationAnalysis::analysisGroupMaxSteps() { + auto &opGroups = fusionStrategy.getOpGroups(); + + for (auto [idx, grp] : llvm::enumerate(opGroups)) { + + uint32_t steps = std::numeric_limits::max(); + + llvm::SmallVector &grpSteps = + fusionStrategy.getGroupMaxSteps(); + while (idx + 1 > grpSteps.size()) + grpSteps.emplace_back(steps); + + std::queue tmpQueue(grp); + auto calculateOpSteps = [&](Type type) { + auto opType = dyn_cast(type); + if (opType) + steps = std::min(steps, (uint32_t)fusionStrategy.getTypeHelper() + .getDataTypeValidSteps(opType)); + }; + while (!tmpQueue.empty()) { + auto op = tmpQueue.front(); + tmpQueue.pop(); + if (isa(op)) + calculateOpSteps(op->getOperandTypes()[0]); + + calculateOpSteps(getOperationVectorType(op).value()); + } + grpSteps[idx] = steps; + } +} +} // namespace gc +} // namespace mlir \ No newline at end of file diff --git a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp index 5459bf1ec..888af1c79 100644 --- a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp +++ b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp @@ -5,13 +5,12 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// -#include "TilingVector.h" +#include "TilingVector.hpp" namespace mlir { namespace gc { #define GEN_PASS_DEF_CPUPHYSICALREGISTERPASS #include "gc/Transforms/Passes.h.inc" -namespace { #define DEBUG_TYPE "lower-to-physical-register-pass" #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") @@ -23,41 +22,28 @@ namespace { arith::FPToSIOp, arith::FPToUIOp, arith::SIToFPOp, arith::UIToFPOp, \ arith::TruncFOp, arith::TruncIOp -#define NOT_NEED_TO_PROCESS_OP \ - linalgx::BatchReduceMatmulVnniOp, linalgx::MultiBatchMatmulOp, \ - linalg::BatchReduceMatmulOp, linalgx::Mm2DVnniOp, linalgx::Mm4DVnniOp, \ - linalg::MatmulOp, linalg::BatchMatmulOp, \ - linalg::BatchMatmulTransposeAOp, linalg::BatchMatmulTransposeBOp, \ - linalg::MatmulTransposeAOp, linalg::MatmulTransposeBOp, \ - linalg::QuantizedBatchMatmulOp, linalg::QuantizedMatmulOp, \ - tensor::CollapseShapeOp, tensor::ExpandShapeOp, tensor::ExtractSliceOp, \ - tensor::InsertSliceOp, microkernel::BrgemmOp - /// TODO: remove it in the future +bool enableDebugPrinter = true; bool disableSpecialOp = false; -bool disableBroadcastOp = false; -bool enableDebugPrinter = false; void printQueue(const std::queue &opQueue) { auto tempQ(opQueue); while (!tempQ.empty()) { auto cur = tempQ.front(); - cur->dump(); + LDBG(*cur); 
tempQ.pop(); } } void printGroupOps(SmallVector, 8> &opGroups) { for (auto [idx, grp] : llvm::enumerate(opGroups)) { - llvm::outs() << " group id: " << idx << "\n"; - if (grp.empty()) { + LDBG("group id: " << idx); + if (grp.empty()) continue; - } - llvm::outs() << "__________________ group start_____________" - << "\n"; + + LDBG("__________________ group start_____________"); printQueue(grp); - llvm::outs() << "__________________ group end_____________" - << "\n"; + LDBG("__________________ group end_______________"); } } @@ -70,24 +56,10 @@ static inline bool isCandidateMoveOperations(Operation *op) { op); } -static inline bool isNotNeedToProcessOp(Operation *op) { - return isa(op); -} - static inline bool isReadOrWriteOperation(Operation *op) { return isa(op); } -/// whether op2 use op1 result -/// Currently we just enable this function for write and read operation -template || - std::is_same_v, - T>> -static bool isOperationsHasDefUseRelation(Operation *op1, Operation *op2) { - return llvm::any_of(op2->getOperands(), - [&op1](Value opd) { return opd.getDefiningOp() == op1; }); -} /// Get the index position of the first element that is true static size_t getFirstTrueIndex(ArrayRef ararys) { for (size_t i = 0; i < ararys.size(); i++) @@ -123,87 +95,12 @@ Value findOriginalTensor(Value writeTensor, Block *block) { return writeTensor; } -/// operation should not contain for loop -bool is_innermost_operation(Operation *op) { - bool inner_most = true; - op->walk([&inner_most](Operation *p) { - if (isa(p)) { - inner_most = false; - return WalkResult::interrupt(); - } - return WalkResult::advance(); - }); - return inner_most; -} - /// whether operation is a not support operation bool isNotSupportOperation(Operation *op) { return isa(op); } -/// Get vector type of the operation \param op -/// \param isPrevOp whether the operation is a previous operation, if it is not -/// prev-op, may need to use result vectortype -/// default will return the opeation result type -mlir::FailureOr getOperationVectorType(Operation *op, - bool isPrevOp = true) { - if (not op) - return failure(); - - auto isDynamicType = [](VectorType &type) { return !type.hasStaticShape(); }; - auto ret = - TypeSwitch>(op) - .Case( - [&](vector::TransferWriteOp transferWriteOp) - -> mlir::FailureOr { - if (auto retType = dyn_cast( - transferWriteOp.getOperandTypes()[0])) - return retType; - - LDBG("TransferWrite Operation has wrong vector to write."); - return failure(); - }) - .Case( - [&](vector::TransferReadOp transferReadOp) - -> mlir::FailureOr { - return transferReadOp.getVectorType(); - }) - .Case( - [&](vector::MultiDimReductionOp multiReductionOp) { - if (isPrevOp) - return cast( - multiReductionOp->getResultTypes()[0]); - - // TODO: may need to add accumulate value vectortype - return cast(multiReductionOp.getSourceVectorType()); - }) - .Default([&](Operation *op) -> mlir::FailureOr { - if (isPrevOp) { - if (op->getResultTypes().empty()) - return failure(); - - if (auto shapedType = - dyn_cast(op->getResultTypes()[0])) - return shapedType; - - return failure(); - } - if (op->getOperandTypes().empty()) - return failure(); - - if (auto shapedType = - dyn_cast(op->getOperandTypes()[0])) { - return shapedType; - } - return failure(); - }); - if (!failed(ret) and isDynamicType(ret.value())) { - return failure(); - } - return ret; -} - /// whether the vector operation is operate on dynamic shape bool hasDynamicShape(Operation *op) { if (failed(getOperationVectorType(op))) { @@ -249,24 +146,6 @@ bool 
hasNotSupportOperation(func::FuncOp *func) { return walkRes != WalkResult::advance(); } -/// select nearest even step -int getNearestVectorStep(const int step) { - assert(step > 0); - int nbits = 0, n = step; - while (n) { - n = n >> 1; - nbits++; - } - assert(nbits <= 6 || (nbits == 7 && step == 64)); - return (1 << (nbits - 1)) == step ? step : (1 << nbits); -} - -/// whether operate on last dimension -bool isLastDim(const AffineExpr &expr, const size_t rank) { - return isa(expr) && - dyn_cast(expr).getPosition() == rank - 1; -} - void GenerateLoopHelper::setNextAnchorArgs( DenseMap &nextAnchorArgsIdxMap, SmallVector &nextAnchorArgs) { @@ -338,36 +217,6 @@ void GenerateLoopHelper::updateCurrentArgsStatus( loopArgsOriginalOperandMap = argsOriginalMap; } -int TypeHelper::generateValidSteps(int steps, VectorType type) { - if (type.getShape().back() >= steps) - return steps; - int evenStep = getNearestVectorStep(type.getShape().back()); - auto typebits = type.getElementTypeBitWidth(); - return evenStep * typebits >= 128 ? evenStep : 1; -} - -// Get the maximum number of current data types that a register can hold -[[nodiscard]] int TypeHelper::getDataTypeMAXSIMDLength(VectorType type) { - auto typebits = type.getElementTypeBitWidth(); - const int favx512bits = 512; - const int favx2bits = 256; - if (HWInfo.favx512f) - return favx512bits / typebits; - - if (HWInfo.favx2) - return favx2bits / typebits; - - // invalid hardware - LDBG("Please check the hardware information."); - assert(false && "Invalid hardware."); - return -1; -} - -/// Get a appropriate for loop step for current vector type -[[nodiscard]] int TypeHelper::getDataTypeValidSteps(VectorType type) { - return generateValidSteps(getDataTypeMAXSIMDLength(type), type); -} - /// get float or integer dense attribute /// \param [in,out] attr template @@ -395,70 +244,6 @@ FailureOr createArithSplatConstantOp(IRRewriter &rewriter, return rewriter.create(loc, attr)->getResults()[0]; } -/// get operation vector type -/// \param isPrevOp whether the operation is a previous operation, if it is not -/// prev-op, may need to use result vectortype -/// default will return the opeation result type -mlir::FailureOr getOperationMaxVectorType(Operation *op) { - if (not op) - return failure(); - - auto isDynamicType = [](VectorType &type) { return !type.hasStaticShape(); }; - auto ret = - TypeSwitch>(op) - .Case( - [&](vector::TransferWriteOp transferWriteOp) - -> mlir::FailureOr { - if (auto retType = - cast(transferWriteOp.getOperandTypes()[0])) - return retType; - return failure(); - }) - .Case( - [&](vector::TransferReadOp transferReadOp) - -> mlir::FailureOr { - return transferReadOp.getVectorType(); - }) - .Case( - [&](vector::MultiDimReductionOp multiReductionOp) { - return cast(multiReductionOp.getSourceVectorType()); - }) - .Default([&](Operation *op) -> mlir::FailureOr { - if (op->getResultTypes().empty() and op->getOperandTypes().empty()) - return failure(); - - if (op->getResultTypes().empty()) - return cast(op->getOperandTypes()[0]); - - if (op->getOperandTypes().empty()) - return cast(op->getResultTypes()[0]); - - auto opdType = cast(op->getOperandTypes()[0]); - auto retType = cast(op->getResultTypes()[0]); - return opdType.getRank() > retType.getRank() ? opdType : retType; - }); - if (!failed(ret) and isDynamicType(ret.value())) - return failure(); - - return ret; -} - -VectorType TypeHelper::getVectorzedType(Operation *op, uint32_t loopStep) { - // Check that the operation type can be broken - // down into a loop. 
- mlir::FailureOr baseType = getOperationVectorType(op); - if (failed(baseType)) { - LDBG("Failed to get vector type for operation: " << *op << "\n"); - assert(0 && "Failed to get vector type for operation"); - return VectorType(); - } - auto vectorizedType = baseType.value(); - if (loopStep == 0) - loopStep = getDataTypeValidSteps(vectorizedType); - - return VectorType::get({loopStep}, vectorizedType.getElementType()); -} - /// whether the operation result need to be returned /// \param anchorIdx resuilt produce operation anchor position /// \param retType resuilt return type @@ -579,31 +364,6 @@ float bfloat2float(uint16_t bfloatBits) { return floatBits.f; } -bool isReadWriteOnLastDim(Operation *op) { - if (isReadOrWriteOperation(op)) { - AffineMap permutationMap = - dyn_cast(op) - ? cast(op).getPermutationMap() - : cast(op).getPermutationMap(); - int64_t rank = - dyn_cast(op) - ? cast(op->getOperand(0).getType()).getRank() - : cast(op->getOperand(1).getType()).getRank(); - ArrayRef dimExpr = permutationMap.getResults(); - bool find = false; - for (const auto &expr : dimExpr) - if (isLastDim(expr, rank)) { - find = true; - break; - } - - return find; - } - LDBG("The operation is not a read or write operation." << *op << "\n"); - assert(0 && "The operation is not a read or write operation."); - return false; -} - std::variant numeric_limits_minimum(Type type) { Type t1 = getElementTypeOrSelf(type); if (t1.isF32()) { @@ -805,7 +565,7 @@ Operation *createTransferWriteOpAfter(Operation *op, const Value &dest) { /*inBounds=*/inBoundsVal); } -Operation *CanonicalizerCommonUsedData::createTransferReadOpBefore( +Operation *GroupOperationFusionImpl::createTransferReadOpBefore( Operation *op, const Value &operand, vector::TransferReadOp *srcReadOp) { auto operandType = cast(operand.getType()); @@ -829,9 +589,11 @@ Operation *CanonicalizerCommonUsedData::createTransferReadOpBefore( /*indices=*/SmallVector(operandType.getRank(), zero), /**affinemap*/ srcReadOpAffineMap, /*inBounds=*/inBoundsVal); - DenseMap &permutationMap = getOpPermuationMap(); + DenseMap &permutationMap = + getGroupOperationFusion().getOpPermuationMap(); permutationMap[t] = srcReadOpAffineMap; - getFusionStrategy().getOpAnchorPos()[t] = t.getVectorType().getRank() - 1; + getGroupOperationFusion().getOpAnchorPos()[t] = + t.getVectorType().getRank() - 1; return t; } @@ -844,9 +606,11 @@ Operation *CanonicalizerCommonUsedData::createTransferReadOpBefore( /*indices=*/SmallVector(operandType.getRank(), zero), /**affinemap*/ padValue, /*inBounds=*/inBoundsVal); - DenseMap &permutationMap = getOpPermuationMap(); + DenseMap &permutationMap = + getGroupOperationFusion().getOpPermuationMap(); permutationMap[t] = t.getPermutationMap(); - getFusionStrategy().getOpAnchorPos()[t] = t.getVectorType().getRank() - 1; + getGroupOperationFusion().getOpAnchorPos()[t] = + t.getVectorType().getRank() - 1; return t; } @@ -861,7 +625,7 @@ canonicalizeSourceOperation(Operation *op, return std::make_pair(resultTensor, writeOp->getResults()[0]); } -[[nodiscard]] Value CanonicalizerCommonUsedData::canonicalizeCurrentOperation( +[[nodiscard]] Value GroupOperationFusionImpl::canonicalizeCurrentOperation( Operation *op, const Value &transferReadOperand, size_t operandIdx, vector::TransferReadOp *srcReadOp) { // transfer_read operation @@ -963,7 +727,7 @@ Value makeIndexArithConstantOp(OpBuilder &opBuilder, const Location &loc, void ForLoopGenerator::moveOperationsToCurrentForBody( const OpBuilder &b, std::queue &opQueue, GenerateLoopHelper &loopHelperParam) { 
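  // [Editorial note] Summary of this helper as visible in the hunk: it drains a copy
  // of the anchor's operation queue and, for each queued op, moves it into the
  // scf.for body currently being built, rewriting induction-dependent operands
  // through the recorded op permutation maps and the loop state carried in
  // loopHelperParam, so reads/writes index with the new induction variables instead
  // of the original tensors.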
- auto &opPermuationMap = getOpPermuationMap(); + auto &opPermuationMap = getVectorBasedFusion().getOpPermuationMap(); auto tmpQ(opQueue); while (!tmpQ.empty()) { auto x = tmpQ.front(); @@ -981,7 +745,7 @@ void ForLoopGenerator::getResultInCurrentOps( DenseMap &forResultOrignalResultMap) { auto tmpQ(ops); llvm::MapVector> &groupResults = - getGroupOpResults()[groupId]; + getVectorBasedFusion().getGroupOpResults()[groupId]; while (!tmpQ.empty()) { Operation *cur = tmpQ.front(); tmpQ.pop(); @@ -1025,8 +789,9 @@ void ForLoopGenerator::getInitArgsToNextAnchor( SmallVector &nextAnchorArgs, GenerateLoopHelper &loopHelperParam) { DenseMap &opAnchorPos = - getFusionStrategy().getOpAnchorPos(); - SetVector &opInitArgs = getGroupOpInitArgs()[loopHelperParam.groupIdx]; + getVectorBasedFusion().getOpAnchorPos(); + SetVector &opInitArgs = + getVectorBasedFusion().getGroupOpInitArgs()[loopHelperParam.groupIdx]; DenseSet visited; // find the next anchor arguments @@ -1061,7 +826,7 @@ void ForLoopGenerator::getOperationInCurrentAnchor( std::queue &toQueue) { while (!fromQueue.empty()) { Operation *curOp = fromQueue.front(); - if (anchorIdx == getFusionStrategy().getOpAnchorPos()[curOp]) { + if (anchorIdx == getVectorBasedFusion().getOpAnchorPos()[curOp]) { toQueue.push(curOp); fromQueue.pop(); continue; @@ -1158,7 +923,7 @@ void ForLoopGenerator::generateLoopResults( currentResultMap); llvm::MapVector> &groupResults = - getGroupOpResults()[loopHelperParam.groupIdx]; + getVectorBasedFusion().getGroupOpResults()[loopHelperParam.groupIdx]; // check for yield results whether need to return to next anchor for (auto [idx, forResult] : llvm::enumerate(loopHelperParam.nextAnchorResults)) { @@ -1201,14 +966,14 @@ void updateLoopArgsData(Value val, Value originalVal, originalOperandLoopArgsMap[originalVal] = val; } -scf::ForOp ForLoopGenerator::reductionAxisGenerateForLoop( +scf::ForOp LoopGeneratorImpl::reductionAxisGenerateForLoop( OpBuilder &opBuilder, const size_t reductionIdx, GenerateLoopHelper &loopHelperParam) { MultiReductionCanonicalizer rdCanonicalizer = getMultiRdCanonicalizers()[loopHelperParam.groupIdx]; auto &multireductionOp = rdCanonicalizer.getCandidateOps()[0]; - VectorFusionStrategy &fusionStrategy = getFusionStrategy(); + GroupOperationFusion &fusionStrategy = getVectorBasedFusion(); SmallVector, 8> &opGroups = fusionStrategy.getOpGroups(); @@ -1219,8 +984,8 @@ scf::ForOp ForLoopGenerator::reductionAxisGenerateForLoop( bool lastDimReduction = rdCanonicalizer.hasLastDimReduction(); VectorType vectorType = rdCanonicalizer.getSourceType(); const int loopStep = - getFusionStrategy().getGroupMaxSteps()[loopHelperParam.groupIdx]; - + getVectorBasedFusion().getGroupMaxSteps()[loopHelperParam.groupIdx]; + func::FuncOp func = fusionStrategy.getFunction(); IRRewriter rewriterOfFunc(func); Value zero = makeIndexArithConstantOp(opBuilder, loc, 0); @@ -1259,12 +1024,9 @@ scf::ForOp ForLoopGenerator::reductionAxisGenerateForLoop( nextAnchorArgs); // replace reduction init args - if (loopHelperParam.originalOperandLoopArgsMap.contains( - multireductionOp.getAcc())) { - size_t accValIdx = - loopHelperParam.currentLoopStateIdxMap - [loopHelperParam.originalOperandLoopArgsMap[multireductionOp - .getAcc()]]; + if (currentoriginalArgsMap.contains(multireductionOp.getAcc())) { + size_t accValIdx = currentArgsIdxMap + [currentoriginalArgsMap[multireductionOp.getAcc()]]; updateCurrentArgsStatus( loopState, accValIdx, nextAnchorArgs, multireductionOp.getAcc(), nextAnchorArgsIdxMap, originalArgsMap, argsOriginalMap); 
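// [Editorial sketch] Shape of the loop nest this reduction lowering corresponds to,
// written as plain C++ over arrays instead of scf.for/vector IR. Assumes a 2-D source
// with one parallel axis (rows) and one reduction axis (columns) and a vector step
// `kStep`; purely illustrative of how the accumulator is threaded as the loop-carried
// value, not the generated IR itself.
#include <array>
#include <cstdio>

int main() {
  constexpr int kRows = 4, kCols = 16, kStep = 8;
  std::array<std::array<float, kCols>, kRows> src{};
  for (int i = 0; i < kRows; ++i)
    for (int j = 0; j < kCols; ++j)
      src[i][j] = 1.0f;

  std::array<float, kRows> result{};
  // Parallel axis: one independent result element per iteration.
  for (int i = 0; i < kRows; ++i) {
    // The per-lane accumulator plays the role of the loop-carried iter_arg.
    float acc[kStep] = {0.0f};
    // Reduction axis, advanced by the vectorization step.
    for (int j = 0; j < kCols; j += kStep)
      for (int k = 0; k < kStep; ++k) // stands in for one vector add
        acc[k] += src[i][j + k];
    // Horizontal reduction of the per-lane partial sums.
    float sum = 0.0f;
    for (int k = 0; k < kStep; ++k)
      sum += acc[k];
    result[i] = sum;
  }
  std::printf("%f\n", result[0]); // 16.000000
  return 0;
}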
@@ -1373,7 +1135,7 @@ scf::ForOp ForLoopGenerator::reductionAxisGenerateForLoop( return forOp; } -void ForLoopGenerator::ensureAccInParallelLoop( +void LoopGeneratorImpl::ensureAccInParallelLoop( GenerateLoopHelper &loopHelperParam, ArrayRef parallelAxis, Value multiReductionAcc, DenseMap &nextAnchorArgsIdxMap, SmallVector &nextAnchorArgs) { @@ -1421,20 +1183,22 @@ void ForLoopGenerator::ensureAccInParallelLoop( /// Generate for loop for parallel axis of `vector.multi_reduction`. /// This function also call reduction axis for loop -scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( +scf::ForOp LoopGeneratorImpl::parallelAxisGenerateForLoop( OpBuilder &opBuilder, GenerateLoopHelper &loopHelperParam) { MultiReductionCanonicalizer &rdCanonicalizer = getMultiRdCanonicalizers()[loopHelperParam.groupIdx]; vector::MultiDimReductionOp &multiReductionOp = rdCanonicalizer.getCandidateOps()[0]; VectorType vectorType = rdCanonicalizer.getSourceType(); + GroupOperationFusion &fusionStrategy = getVectorBasedFusion(); + func::FuncOp func = fusionStrategy.getFunction(); IRRewriter rewriterOfFunc(func); SmallVector ¶llelAxis = rdCanonicalizer.getParallelAxis(); const Location &loc = multiReductionOp.getLoc(); Value zero = makeIndexArithConstantOp(opBuilder, loc, 0); size_t grpMaxStep = - getFusionStrategy().getGroupMaxSteps()[loopHelperParam.groupIdx]; + getVectorBasedFusion().getGroupMaxSteps()[loopHelperParam.groupIdx]; size_t actualStep = (loopHelperParam.anchorIdx == parallelAxis.size() - 1 ? grpMaxStep : 1); Value forSteps = makeIndexArithConstantOp(opBuilder, loc, actualStep); @@ -1442,7 +1206,8 @@ scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( // last dim reduction need to a generate dim=16 loop for fused with pre-op int dimSize = 0; if (loopHelperParam.anchorIdx == parallelAxis.size()) - dimSize = getFusionStrategy().getGroupMaxSteps()[loopHelperParam.groupIdx]; + dimSize = + getVectorBasedFusion().getGroupMaxSteps()[loopHelperParam.groupIdx]; else dimSize = vectorType.getShape()[parallelAxis[loopHelperParam.anchorIdx]]; @@ -1452,7 +1217,7 @@ scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( loc, zero, numIter, forSteps, loopHelperParam.loopIterArgs, [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) { loopHelperParam.inductionVars.emplace_back(iv); - VectorFusionStrategy &fusionStrategy = getFusionStrategy(); + DenseMap &opIndexMap = fusionStrategy.getOpGroupIndexMap(); @@ -1522,7 +1287,8 @@ scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( auto accVal = b.create( loc, DenseElementsAttr::get( - getVectorzedType(multiReductionOp, dimSize), + fusionStrategy.getTypeHelper().getVectorzedType( + multiReductionOp, dimSize), {initValueAttr})); // put accumulte val at first for loop args @@ -1597,7 +1363,7 @@ scf::ForOp ForLoopGenerator::parallelAxisGenerateForLoop( }); } -scf::ForOp ForLoopGenerator::generateTransposeForLoopWithLastDim( +scf::ForOp LoopGeneratorImpl::generateTransposeForLoopWithLastDim( OpBuilder &opBuilder, const int tpSteps, const Location &loc, Operation *successorWriteOp, GenerateLoopHelper &loopHelperParam) { auto &tpCanonicalizer = @@ -1670,7 +1436,8 @@ scf::ForOp ForLoopGenerator::generateTransposeForLoopWithLastDim( void ForLoopGenerator::prepareForLoopArgs(const size_t grpIdx, GenerateLoopHelper &loopHelper) { - SetVector &grpArgs = getGroupOpInitArgs()[grpIdx]; + SetVector &grpArgs = + getVectorBasedFusion().getGroupOpInitArgs()[grpIdx]; loopHelper.loopIterArgs = grpArgs.getArrayRef(); for (auto [idx, val] : 
llvm::enumerate(grpArgs)) { loopHelper.currentLoopStateIdxMap[val] = idx; @@ -1679,7 +1446,7 @@ void ForLoopGenerator::prepareForLoopArgs(const size_t grpIdx, } } -void ForLoopGenerator::rearrageMultiReductionIR( +void LoopGeneratorImpl::rearrageMultiReductionIR( const size_t grpIdx, DenseMap> &indiceLoopMap) { MultiReductionCanonicalizer &rdCanonicalizer = @@ -1693,7 +1460,8 @@ void ForLoopGenerator::rearrageMultiReductionIR( std::queue &accRelatedOps = rdCanonicalizer.getAccRelatedOps(); std::queue &sourceRelatedOps = rdCanonicalizer.getSourceRelatedOps(); - std::queue &opQueue = getFusionStrategy().getOpGroups()[grpIdx]; + std::queue &opQueue = + getVectorBasedFusion().getOpGroups()[grpIdx]; auto copyOpQueue(opQueue); getPrevOps(prevOps, copyOpQueue, multiReductionOp); getPostOps(postOps, copyOpQueue, multiReductionOp); @@ -1704,7 +1472,7 @@ void ForLoopGenerator::rearrageMultiReductionIR( std::queue tmpSourceQ(sourceRelatedOps); DenseMap varLoopIdxMap; VectorType groupVector = - getFusionStrategy().getGroupBiggestRankVectorType()[grpIdx]; + getVectorBasedFusion().getGroupBiggestRankVectorType()[grpIdx]; for (size_t i = 0; i < parallelAxis.size(); i++) { varLoopIdxMap[parallelAxis[i]] = i; } @@ -1744,7 +1512,7 @@ void ForLoopGenerator::replaceOpUsersWithForLoopResult( scf::ForOp forOp, int grpIdx, SmallVector &nextAnchorResults, DenseMap &nextAnchorResultsIdxMap, DenseMap &forResultOrignalResultMap) { - IRRewriter rewriter(func); + IRRewriter rewriter(forOp); DenseSet forOpChildOps; forOp->walk([&](Operation *op) { forOpChildOps.insert(op); }); auto replaceIfFn = [&](OpOperand &use) { @@ -1760,7 +1528,7 @@ void ForLoopGenerator::replaceOpUsersWithForLoopResult( } } scf::ForOp -ForLoopGenerator::generateMultiReductionForLoop(const size_t grpIdx) { +LoopGeneratorImpl::generateMultiReductionForLoop(const size_t grpIdx) { DenseMap> indiceLoopMap; rearrageMultiReductionIR(grpIdx, indiceLoopMap); @@ -1780,16 +1548,16 @@ ForLoopGenerator::generateMultiReductionForLoop(const size_t grpIdx) { loopHelper.nextAnchorResultsIdxMap, loopHelper.nextAnchorResultOrignalResultMap); - IRRewriter rewriter(func); vector::MultiDimReductionOp multiReductionOp = rdCanonicalizer.getCandidateOps()[0]; + IRRewriter rewriter(multiReductionOp); rewriter.eraseOp(multiReductionOp); return forOp; } // generate simple data movement for loop -scf::ForOp ForLoopGenerator::generateTransposeScalarDataMovement( +scf::ForOp LoopGeneratorImpl::generateTransposeScalarDataMovement( OpBuilder &opBuilder, const Location &loc, DenseMap &tpAxisMap, GenerateLoopHelper &loopHelperParam) { auto &tpCanonicalizer = @@ -1864,7 +1632,7 @@ scf::ForOp ForLoopGenerator::generateTransposeScalarDataMovement( }); } -scf::ForOp ForLoopGenerator::generateShapeCastReadWriteLoop( +scf::ForOp LoopGeneratorImpl::generateShapeCastReadWriteLoop( OpBuilder &opBuilder, const size_t grpIdx, const size_t forDimIdx, const size_t steps, const Location &loc, SmallVector &inductionVars, ValueRange iterArgs) { @@ -1876,7 +1644,7 @@ scf::ForOp ForLoopGenerator::generateShapeCastReadWriteLoop( sourceType.getRank() > destType.getRank() ? 
sourceType : destType; size_t rank = loopType.getRank(); DenseMap &opIndexMap = - getFusionStrategy().getOpGroupIndexMap(); + getVectorBasedFusion().getOpGroupIndexMap(); auto zero = makeIndexArithConstantOp(opBuilder, loc, 0); bool isLastDim = loopType.getRank() - 1 == (int64_t)forDimIdx; @@ -2026,7 +1794,7 @@ void ForLoopGenerator::rectifyReadOperationIndice( // currently only broadcast (fuse as transfer_read) will move into more inner // loop if (readTensorType.getRank() - 1 >= - (int64_t)getFusionStrategy().getOpAnchorPos()[*originalReadOp]) + (int64_t)getVectorBasedFusion().getOpAnchorPos()[*originalReadOp]) return; int64_t itrIdx = loopType.getRank() - 1; @@ -2041,7 +1809,7 @@ void ForLoopGenerator::rectifyReadOperationIndice( } /// generate transpose for loop -scf::ForOp ForLoopGenerator::generateShapeCastForLoop(const size_t grpIdx) { +scf::ForOp LoopGeneratorImpl::generateShapeCastForLoop(const size_t grpIdx) { ShapeCastCanonicalizer &scCanonicalizer = getShapeCastCanonicalizers()[grpIdx]; @@ -2050,7 +1818,7 @@ scf::ForOp ForLoopGenerator::generateShapeCastForLoop(const size_t grpIdx) { VectorType sourceType = scOp.getSourceVectorType(); VectorType destType = scOp.getResultVectorType(); DenseMap &opIndexMap = - getFusionStrategy().getOpGroupIndexMap(); + getVectorBasedFusion().getOpGroupIndexMap(); OpBuilder b(scOp); SmallVector iterArgs; @@ -2064,8 +1832,8 @@ scf::ForOp ForLoopGenerator::generateShapeCastForLoop(const size_t grpIdx) { iterArgs.emplace_back(successorWriteOp->getOperands()[1]); SmallVector inductionVars; - IRRewriter rewriter(func); - const size_t groupStep = getFusionStrategy().getGroupMaxSteps()[grpIdx]; + IRRewriter rewriter(scOp); + const size_t groupStep = getVectorBasedFusion().getGroupMaxSteps()[grpIdx]; bool isSourceMultiple = sourceType.getShape()[sourceType.getRank() - 1] % groupStep == 0; @@ -2100,7 +1868,7 @@ void ForLoopGenerator::getCurrentGroupIndiceLoopMap( if (setIdxMap.empty()) { DenseMap forIdxMap; VectorType groupVector = - getFusionStrategy().getGroupBiggestRankVectorType()[groupId]; + getVectorBasedFusion().getGroupBiggestRankVectorType()[groupId]; for (size_t i = 0; (int64_t)i < groupVector.getRank(); i++) { forIdxMap[i] = i; } @@ -2111,16 +1879,16 @@ void ForLoopGenerator::getCurrentGroupIndiceLoopMap( } void ForLoopGenerator::clearCurrentOperationGroup(size_t grpIdx) { - std::queue().swap(getFusionStrategy().getOpGroups()[grpIdx]); + std::queue().swap(getVectorBasedFusion().getOpGroups()[grpIdx]); }; -scf::ForOp ForLoopGenerator::generateTransposeForLoop(const size_t grpIdx) { +scf::ForOp LoopGeneratorImpl::generateTransposeForLoop(const size_t grpIdx) { // transpose rank must bigger than 2 TransposeCanonicalizer &tpCanonicalizer = getTransposeCanonicalizers()[grpIdx]; vector::TransposeOp &tpOp = tpCanonicalizer.getCandidateOps()[0]; - IRRewriter rewriter(func); + IRRewriter rewriter(tpOp); VectorType vtType = tpOp.getResultVectorType(); size_t rank = vtType.getRank(); @@ -2136,8 +1904,9 @@ scf::ForOp ForLoopGenerator::generateTransposeForLoop(const size_t grpIdx) { bool isTwoDTranspose = tpCanonicalizer.isTwoDTranspose(); Operation *successorWriteOp = - getNextTargetOperationInCurrentGroup(tpOp, - grpIdx); + getVectorBasedFusion() + .getNextTargetOperationInCurrentGroup( + tpOp, grpIdx); DenseMap operandIdxMap; DenseMap originalOperandMap, operandOriginalMap, resultIdxMap, @@ -2356,26 +2125,26 @@ void addDummyInit(SmallVector &canonicalizer, size_t steps = 1) { canonicalizer.emplace_back(T({}, steps)); }; -void 
CanonicalizerVectorOperation::clearSpecialOperationCanonicalizers() { +void LoopGeneratorImpl::clearSpecialOperationCanonicalizers() { getMultiRdCanonicalizers().clear(); getBroadcastCanonicalizers().clear(); getTransposeCanonicalizers().clear(); getShapeCastCanonicalizers().clear(); } -void CanonicalizerVectorOperation::dummyInitSpecialOperation(size_t steps) { +void LoopGeneratorImpl::dummyInitSpecialOperation(size_t steps) { addDummyInit(getMultiRdCanonicalizers(), steps); addDummyInit(getBroadcastCanonicalizers(), steps); addDummyInit(getTransposeCanonicalizers(), steps); addDummyInit(getShapeCastCanonicalizers(), steps); } -void CanonicalizerVectorOperation::initSpeicalOperationCanonicalizers() { +void LoopGeneratorImpl::initSpeicalOperationCanonicalizers() { clearSpecialOperationCanonicalizers(); SmallVector, 8> &opGroups = - getFusionStrategy().getOpGroups(); + getVectorBasedFusion().getOpGroups(); for (auto [idx, grp] : llvm::enumerate(opGroups)) { - dummyInitSpecialOperation(getFusionStrategy().getGroupMaxSteps()[idx]); + dummyInitSpecialOperation(getVectorBasedFusion().getGroupMaxSteps()[idx]); if (grp.empty()) continue; @@ -2408,7 +2177,7 @@ void CanonicalizerVectorOperation::initSpeicalOperationCanonicalizers() { } template -void CanonicalizerVectorOperation::processSpecialOperation( +void LoopGeneratorImpl::processSpecialOperation( T &canonicalizers, const std::function &generateFunc) { for (auto [groupId, canonicalizer] : llvm::enumerate(canonicalizers)) { SmallVector &ops = canonicalizer.getCandidateOps(); @@ -2418,8 +2187,7 @@ void CanonicalizerVectorOperation::processSpecialOperation( } } -void CanonicalizerVectorOperation::canonicalizeSpecialOperation() { - OpBuilder::InsertionGuard guard(rewriter); +void LoopGeneratorImpl::canonicalizeSpecialOperation() { initSpeicalOperationCanonicalizers(); // traverse all groups @@ -2446,9 +2214,10 @@ void CanonicalizerVectorOperation::canonicalizeSpecialOperation() { [this](const size_t grpIdx) { (void)generateShapeCastForLoop(grpIdx); }); } -void CanonicalizerVectorOperation::run() { - auto &fusionStrategy = getFusionStrategy(); - if (kind == CanonicalizerKind::OperationsGroup) { +void VectorOperationCanonicalizer::run() { + auto &fusionStrategy = fusion.getGroupOperationFusion(); + if (kind == CanonicalizerKind::GroupOperations) { + fusion.run(); // 1. Analysis the operation's operands and results // We need to analyze which operation's result is needed by other // operations, and we need to pass these results correctly. Mapping the @@ -2485,61 +2254,33 @@ void CanonicalizerVectorOperation::run() { // on it. Therefore, `empty tensor`, `transfer_write` and `transfer_read` // need to be inserted at target place. 
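// [Editorial illustration] The hand-off described in the comment above, written out
// for a hypothetical 16x32xf32 value produced by one group and consumed by another
// (the shapes, the %buf/%vecA/%vecB names, and the %c0/%pad constants are invented
// for illustration; the pass derives the real ones from the source operation):
//
//   %buf    = tensor.empty() : tensor<16x32xf32>
//   // yielded out of the producing group's loop nest:
//   %stored = vector.transfer_write %vecA, %buf[%c0, %c0]
//                 {in_bounds = [true, true]} : vector<16x32xf32>, tensor<16x32xf32>
//   // re-materialized where the consuming group needs it:
//   %vecB   = vector.transfer_read %stored[%c0, %c0], %pad
//                 {in_bounds = [true, true]} : tensor<16x32xf32>, vector<16x32xf32>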
if (enableDebugPrinter) { - printGroupOps(getFusionStrategy().getOpGroups()); - llvm::outs() << "___________ before analysis ________________" - << "\n"; + printGroupOps(fusion.getGroupOperationFusion().getOpGroups()); + LDBG("___________ before analysis ________________"); } - analysisGroupOperaion(); + fusion.canonicalizeEachOperationGroup(); if (enableDebugPrinter) { - llvm::outs() << "___________ after analysis ________________" - << "\n"; - printGroupOps(getFusionStrategy().getOpGroups()); + LDBG("___________ after analysis ________________"); + printGroupOps(fusion.getGroupOperationFusion().getOpGroups()); } + loopGenerator.setVectorBaseFusion(fusion.getGroupOperationFusion()); // Speical Operation Canonicalization - canonicalizeSpecialOperation(); + loopGenerator.canonicalizeSpecialOperation(); // 2.Generate vectorized IR for each operation group for (size_t idx = 0; idx < fusionStrategy.getOpGroups().size(); ++idx) - generateGroupOpVectorizedIR(idx); + loopGenerator.generateGroupOpVectorizedIR(idx); // 3. Some IR cleanup work DominanceInfo domInfo; - eliminateCommonSubExpressions(rewriter, domInfo, func); + eliminateCommonSubExpressions( + rewriter, domInfo, loopGenerator.getVectorBasedFusion().getFunction()); } else { // TODO: need to add directly canonicalize operations logic // generateGroupOpVectorizedIR(idx, grp, fusionStrategy.opGroupIndexMap); } } -// Filter out the operations that can be vectorized. We are only interested in -// operations that do not contain any for loops(innermost IR). -[[nodiscard]] bool filterOperation(Operation *op) { - if (!is_innermost_operation(op)) { - LDBG("Operation is not innermost" << *op << "\n"); - return false; - } - - // We are only interested about the operation in vector dialect - if (failed(getOperationVectorType(op))) { - LDBG("Operation is not in vector dialect" << *op << "\n"); - return false; - } - - // We don't need to vectorize the constant operation - if (isa(op)) { - LDBG("Operation is constantOp" << *op << "\n"); - return false; - } - - if (isReadOrWriteOperation(op) and !isReadWriteOnLastDim(op)) { - LDBG("Operation is not last dim read/write" << *op << "\n"); - return false; - } - - return true; -} - /// void ForLoopGenerator::setOperationCorrectOperand( Operation *op, const DenseMap &opPermuationMap, @@ -2598,9 +2339,9 @@ void ForLoopGenerator::setOperationCorrectOperand( op->setOperand(dim + offset, loopHelperParam.inductionVars[varIdx]); } if (auto readOp = dyn_cast(op)) { - size_t grpIdx = getFusionStrategy().getOpGroupIndexMap()[op]; + size_t grpIdx = getVectorBasedFusion().getOpGroupIndexMap()[op]; VectorType loopType = - getFusionStrategy().getGroupBiggestRankVectorType()[grpIdx]; + getVectorBasedFusion().getGroupBiggestRankVectorType()[grpIdx]; SmallVector readIndices(readOp.getIndices().begin(), readOp.getIndices().end()); rectifyReadOperationIndice(&readOp, loopType, @@ -2613,7 +2354,7 @@ void ForLoopGenerator::setOperationCorrectOperand( scf::ForOp ForLoopGenerator::constructNestedForOp( const size_t groupIdx, OpBuilder &b, const Location &loc, ArrayRef dims, GenerateLoopHelper &loopHelper) { - const int loop_step = getFusionStrategy().getGroupMaxSteps()[groupIdx]; + const int loop_step = getVectorBasedFusion().getGroupMaxSteps()[groupIdx]; // loop initialization variable auto zero = makeIndexArithConstantOp(b, loc, 0); auto forSteps = makeIndexArithConstantOp( @@ -2629,7 +2370,7 @@ scf::ForOp ForLoopGenerator::constructNestedForOp( // inner most body of the loop if (loopHelper.anchorIdx == dims.size() - 1) { 
std::queue &opQueue = - getFusionStrategy().getOpGroups()[groupIdx]; + getVectorBasedFusion().getOpGroups()[groupIdx]; loopHelper.loopIterArgs = loopState; // 1. get operations in current anchor position std::queue movingOperation; @@ -2662,7 +2403,7 @@ scf::ForOp ForLoopGenerator::constructNestedForOp( std::queue movedQueue; std::queue &opQueue = - getFusionStrategy().getOpGroups()[groupIdx]; + getVectorBasedFusion().getOpGroups()[groupIdx]; SmallVector tmpArgs(loopState); loopHelper.updateDataBeforePreOpMove(tmpArgs, opQueue, movedQueue); movePreOpToCurrentAnchor(b, nextAnchorArgsIdxMap, nextAnchorArgs, @@ -2693,383 +2434,6 @@ scf::ForOp ForLoopGenerator::constructNestedForOp( return forOp; } -/// default op1 is previous operation -bool VectorFusionStrategy::isCompatibleVectorType(Operation *op1, - Operation *op2) { - // only lower to vector pass can produce read operation. In general two read - // operation is compatible - if (isa(op1) and isa(op2)) { - return true; - } - - mlir::FailureOr type1 = getOperationVectorType(op1, true); - mlir::FailureOr type2 = getOperationVectorType(op2, false); - // some operation has two different operands type like multireduction, we need - // to check whether compitable with accumulate vector - VectorType suppleType; - if (failed(type1) || failed(type2)) - return false; - - auto sp1 = type1.value(); - auto sp2 = type2.value(); - - auto isCompatible = [](VectorType sp1, VectorType sp2) { - bool isCompatible = true; - auto min_rank = std::min(sp1.getRank(), sp2.getRank()); - // from front to back - for (long i = 0; i < min_rank; i++) { - if (sp1.getDimSize(i) != sp2.getDimSize(i)) { - isCompatible = false; - break; - } - } - return isCompatible; - }; - - bool result; - result = isCompatible(sp1, sp2); - // operand check only happen on later operation is op2 - // TODO: may need to support other similar operation like multireduction has - // two different operands type - if (isa(op2)) { - suppleType = cast(op2->getOperandTypes()[1]); - result |= isCompatible(suppleType, sp1); - } - - return result; -} - -/// which axis do the shape cast in source shape a -void shapeCastSourceAxis(const ArrayRef &a, const ArrayRef &b, - SmallVector &res) { - unsigned rankA = a.size(); - unsigned rankB = b.size(); - assert(rankA < rankB && "May be invalid shape cast operation."); - - auto isOne = [](int64_t v) { return v == 1; }; - - // Special-case for n-D to 0-d shape cast. 'b' must be all ones to be shape - // casted to a 0-d vector. - if (rankA == 0 && all_of(b, isOne)) { - for (size_t i = 0; i < a.size(); i++) { - res.emplace_back(i); - } - return; - } - - unsigned i = 0; - unsigned j = 0; - while (i < rankA && j < rankB) { - int64_t dimA = a[i]; - int64_t dimB = 1; - int64_t bAxisBegin = j; - while (dimB < dimA && j < rankB) - dimB *= b[j++]; - if (dimA != dimB) { - assert(false && " Invalid shape cast operation."); - break; - } - if (bAxisBegin != j) { - res.emplace_back(i); - } - ++i; - - // Handle the case when trailing dimensions are of size 1. - // Include them into the contiguous sequence. 
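  // [Editorial example] Walking a = [2, 12] against b = [2, 3, 4]: source axis 0
  // consumes b[0] = 2, source axis 1 consumes b[1] * b[2] = 12, so both axes {0, 1}
  // are recorded. The checks below then absorb any trailing unit dims on either side
  // (e.g. b = [2, 3, 4, 1]) into that contiguous walk.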
- if (i < rankA && all_of(a.slice(i), isOne)) - i = rankA; - if (j < rankB && all_of(b.slice(j), isOne)) - j = rankB; - } - - assert(i == rankA && j == rankB && "Invalid shapecast operation."); -} - -bool isScalar(Type type) { - assert(type && "Not a valid type"); - if (auto vecType = dyn_cast(type)) - return false; - if (auto tensorType = dyn_cast(type)) - return false; - return true; -} - -void getSrcBroadcastDim(const ShapedType &input, const ShapedType &output, - SmallVector &bcAxis) { - auto inputShape = input.getShape(); - auto outputShape = output.getShape(); - // following auto_broadcast semantics - const size_t input_rank = inputShape.size(); - const size_t output_rank = outputShape.size(); - assert(output_rank >= input_rank && - "Incorrect input or output shape for broadcast op."); - const size_t offset = output_rank - input_rank; - for (size_t i = 0; i < input_rank; ++i) { - if (inputShape[i] == outputShape[i + offset] || - (ShapedType::isDynamic(inputShape[i]) && - ShapedType::isDynamic(outputShape[i + offset]))) { - bcAxis.emplace_back(i); - } - } - if (bcAxis.empty()) - bcAxis.emplace_back(-1); -} - -void getOperationDataAxis(Operation *op, SmallVector &dataAxis) { - return TypeSwitch(op) - .Case( - [&](vector::MultiDimReductionOp multiReductionOp) { - auto rdDimsRange = multiReductionOp.getReductionDims(); - dataAxis.assign(rdDimsRange.begin(), rdDimsRange.end()); - return; - }) - .Case([&](vector::ShapeCastOp shapeCastOp) { - auto srcType = shapeCastOp.getSourceVectorType(); - auto dstType = shapeCastOp.getResultVectorType(); - auto srcShape = srcType.getShape(); - auto dstShape = dstType.getShape(); - if (srcShape.size() < dstShape.size()) { - shapeCastSourceAxis(srcShape, dstShape, dataAxis); - } else { - shapeCastSourceAxis(dstShape, srcShape, dataAxis); - } - return; - }) - .Case([&](vector::BroadcastOp broadcastOp) { - auto srcType = broadcastOp.getSourceType(); - auto dstType = broadcastOp.getResultVectorType(); - if (isScalar(srcType)) { - dataAxis.emplace_back(0); - } else { - auto inputType = mlir::cast(srcType); - auto outputType = mlir::cast(dstType); - getSrcBroadcastDim(inputType, outputType, dataAxis); - } - return; - }) - .Case([&](vector::TransposeOp transposeOp) { - auto perm = transposeOp.getPermutation(); - int start = 0; - for (auto x : perm) { - if (x != start) { - dataAxis.emplace_back(x); - } - start++; - } - return; - }) - .Default([&](Operation *op) { - // default is last axis - dataAxis.emplace_back( - cast(op->getResultTypes()[0]).getRank() - 1); - return; - }); -} - -static inline bool hasSameAxis(ArrayRef dims1, - ArrayRef dims2) { - DenseSet checkSet(dims2.begin(), dims2.end()); - return llvm::any_of(dims1, - [&checkSet](int64_t x) { return checkSet.contains(x); }); -} - -/// whether two operation has data dependency -/// op1 default is previous operation, op2 default is current operation -bool hasDataDependency(Operation *op1, Operation *op2) { - if (!isSpecialOp(op1) and !isSpecialOp(op2)) - return false; - - // TODO: Remove this condition to support special operation fusion in the - // future - if (disableSpecialOp) - return true; - - if (isReadOrWriteOperation(op1) or isReadOrWriteOperation(op2)) { - // if op1 is read the value and pass it to op2, it is not data dependency - if (isOperationsHasDefUseRelation(op1, op2)) - return false; - } - - // broadcast only fuse with post-op - if (isa(op2)) - return true; - - if (isa(op1) and disableBroadcastOp) - return true; - - // only special operation may cause data dependency - if 
(!isSpecialOp(op1)) - return hasDataDependency(op2, op1); - - auto res = - TypeSwitch(op1) - .Case([&](vector::ShapeCastOp shapeCastOp) { - SmallVector dims1, dims2; - getOperationDataAxis(op1, dims1); - getOperationDataAxis(op2, dims2); - if (!isSpecialOp(op2)) - return hasSameAxis(dims1, dims2); - - return true; - }) - .Case( - [&](vector::MultiDimReductionOp multiReductionOp) { - SmallVector dims2, reductionDims, parallelDims; - getOperationDataAxis(op1, reductionDims); - getOperationDataAxis(op2, dims2); - DenseSet checkSet(dims2.begin(), dims2.end()); - auto op2VectorType = getOperationVectorType(op2); - if (!isSpecialOp(op2)) { - // all reduction axis should be op2's data axis - bool reduceDependent = false; - for (auto x : reductionDims) { - if (!checkSet.contains(x)) { - reduceDependent = true; - break; - } - } - if (!reduceDependent) - return false; - - // all parallel axis should equal to op2's axis - checkSet.clear(); - checkSet.insert(reductionDims.begin(), reductionDims.end()); - auto rdRank = - multiReductionOp.getSourceVectorType().getRank(); - for (auto i = 0; i < rdRank; i++) - if (not checkSet.contains(i)) - parallelDims.emplace_back(i); - - checkSet.clear(); - checkSet.insert(parallelDims.begin(), parallelDims.end()); - auto rank = op2VectorType->getRank(); - for (auto i = 0; i < rank; i++) - if (!checkSet.contains(i)) - return true; - - return false; - } - - return true; - }) - .Case([&](vector::BroadcastOp broadcastOp) { - if (isSpecialOp(op2)) - return true; - - return !OpTrait::util::staticallyKnownBroadcastable( - getOperationVectorType(op1, false)->getShape(), - getOperationVectorType(op2)->getShape()); - }) - .Case( - [&](vector::TransposeOp transposeOp) { return true; }) - .Default([&](Operation *op) { return false; }); - - return res; -} - -/// Get the operation which is not a read-write in current queue -/// \param [in, out] op -Operation *getNotReadWriteOperaiton(std::queue &tmpQ) { - Operation *op = nullptr; - while (!tmpQ.empty()) { - Operation *cur = tmpQ.front(); - tmpQ.pop(); - if (isReadOrWriteOperation(cur)) - continue; - - op = cur; - } - return op; -} - -bool VectorFusionStrategy::isNeedNewGroup(Operation *op) { - if (isa(op)) { - notNeedToJudgeOps.push(op); - return false; - } - // 1. check previous operation - if (!opGroups.back().empty()) { - // We only care about the calculation operation. 
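    // [Editorial example] e.g. an elementwise arith.addf followed by a
    // vector.transpose starts a new group: the transpose reorders data across the
    // permuted axes and hasDataDependency() reports true for it, so the two ops
    // cannot share one innermost loop nest.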
- std::queue tmpQ(opGroups.back()); - Operation *prevOp = nullptr; - prevOp = getNotReadWriteOperaiton(tmpQ); - if (!prevOp) { - // if previous operation is not in the same block, we need to create a - // group - return opGroups.back().back()->getParentOp() != op->getParentOp() or - isSpecialOp(op); - } - - if (prevOp->getParentOp() != op->getParentOp()) - return true; - - // special operation need to check data dependency axis - if (hasDataDependency(prevOp, op)) - return true; - - // previous operation vector type is not compatible with current operation - if (!isCompatibleVectorType(prevOp, op)) - return true; - } - return false; -} - -void VectorFusionStrategy::updateGroupBigestVectorType(VectorType vectorType) { - int64_t rank = vectorType.getRank(); - llvm::SmallDenseMap &groupVectorType = - getGroupBiggestRankVectorType(); - - if (groupVectorType.contains(opGroups.size() - 1)) { - VectorType bigestType = groupVectorType[opGroups.size() - 1]; - if (bigestType.getRank() < rank) - groupVectorType[opGroups.size() - 1] = vectorType; - - return; - } - - groupVectorType[opGroups.size() - 1] = vectorType; -} - -void VectorFusionStrategy::addOperationToGroup(Operation *op) { - assert(op); - VectorType vectorType = getOperationMaxVectorType(op).value(); - if (isNeedNewGroup(op)) - opGroups.emplace_back(std::queue()); - - if (not isa(op)) { - updateGroupBigestVectorType(vectorType); - while (not notNeedToJudgeOps.empty()) { - auto cur = notNeedToJudgeOps.front(); - notNeedToJudgeOps.pop(); - opGroupIndexMap[cur] = opGroups.size() - 1; - opGroups.back().push(cur); - } - opGroups.back().push(op); - opGroupIndexMap[op] = opGroups.size() - 1; - } - opAnchorPos[op] = getOperationMaxVectorType(op)->getRank() - 1; -} - -// We classify the operations we are interested in after filtering. Operations -// of in the same group have no data dependencies. Those operations can generate -// a same outter for loop. -void VectorFusionStrategy::classifyOperations() { - // dummpy - if (opGroups.empty()) - opGroups.emplace_back(std::queue()); - - func->walk([&](Operation *op) { - if (filterOperation(op)) { - addOperationToGroup(op); - return WalkResult::advance(); - } - if (isNotNeedToProcessOp(op) and !opGroups.back().empty()) - opGroups.emplace_back(std::queue()); - - return WalkResult::advance(); - }); -} - Value setOutGroupOperationOperandResult(Operation *op, const VectorType &newOperandType) { auto ret = @@ -3135,11 +2499,13 @@ void setOperationOperandResult(Operation *op, const VectorType &newOperandType, void ForLoopGenerator::createNewConstantOp( Operation *srcOp, vector::TransferWriteOp *transferWriteOp, size_t groupSteps) { - DenseMap &opPermuationMap = getOpPermuationMap(); + DenseMap &opPermuationMap = + getVectorBasedFusion().getOpPermuationMap(); IRRewriter srcWriter(srcOp); VectorType newOperandType = - getVectorzedType(cast(srcOp), groupSteps); + getVectorBasedFusion().getTypeHelper().getVectorzedType( + cast(srcOp), groupSteps); auto srcConstantOp = dyn_cast(srcOp); Operation *newConstantOp; if (isa(srcConstantOp.getValue())) { @@ -3175,18 +2541,20 @@ void ForLoopGenerator::createNewConstantOp( void ForLoopGenerator::rewriteOperationAsVectorize( OpBuilder &rewriter, size_t groupId, const std::queue *queue) { const std::queue groupOps = - !queue ? getFusionStrategy().getOpGroups()[groupId] : *queue; + !queue ? 
getVectorBasedFusion().getOpGroups()[groupId] : *queue; const DenseMap &opMap = - getFusionStrategy().getOpGroupIndexMap(); - DenseMap &opPermuationMap = getOpPermuationMap(); + getVectorBasedFusion().getOpGroupIndexMap(); + DenseMap &opPermuationMap = + getVectorBasedFusion().getOpPermuationMap(); std::queue transformQueue(groupOps); - size_t groupSteps = getFusionStrategy().getGroupMaxSteps()[groupId]; + size_t groupSteps = getVectorBasedFusion().getGroupMaxSteps()[groupId]; while (!transformQueue.empty()) { Operation *op = transformQueue.front(); transformQueue.pop(); - VectorType newOperandType = getVectorzedType(op, groupSteps); + VectorType newOperandType = + getVectorBasedFusion().getTypeHelper().getVectorzedType(op, groupSteps); auto lowerResult = TypeSwitch(op) .Case( @@ -3270,9 +2638,11 @@ mlir::FailureOr getOperationOperateTensor(Operation *op) { }); } -void CanonicalizerCommonUsedData::removeOpInCurrentGroups( - size_t grpIdx, Operation *op, Operation *replacedOp) { - std::queue tmpOpQueue(getFusionStrategy().getOpGroups()[grpIdx]); +void GroupOperationFusionImpl::removeOpInCurrentGroups(size_t grpIdx, + Operation *op, + Operation *replacedOp) { + std::queue tmpOpQueue( + getGroupOperationFusion().getOpGroups()[grpIdx]); std::queue newOpQueue; while (!tmpOpQueue.empty()) { auto curOp = tmpOpQueue.front(); @@ -3281,31 +2651,32 @@ void CanonicalizerCommonUsedData::removeOpInCurrentGroups( newOpQueue.push(curOp); continue; } - getFusionStrategy().getOpGroupIndexMap().erase(curOp); - getFusionStrategy().getOpAnchorPos().erase(curOp); + getGroupOperationFusion().getOpGroupIndexMap().erase(curOp); + getGroupOperationFusion().getOpAnchorPos().erase(curOp); } - getFusionStrategy().getOpGroups()[grpIdx] = newOpQueue; + getGroupOperationFusion().getOpGroups()[grpIdx] = newOpQueue; // erase and replace the operation SmallVector usesOp(op->getUsers().begin(), op->getUsers().end()); IRRewriter rewriter(op); rewriter.replaceOp(op, replacedOp); // update removed operation related operation anchor position - getFusionStrategy().getOpAnchorPos()[replacedOp] = + getGroupOperationFusion().getOpAnchorPos()[replacedOp] = getOperationMaxVectorType(replacedOp)->getRank() - 1; for (Operation *x : usesOp) - getFusionStrategy().getOpAnchorPos()[x] = + getGroupOperationFusion().getOpAnchorPos()[x] = getOperationMaxVectorType(x)->getRank() - 1; updateOpGroupInfo(grpIdx); } -void CanonicalizerCommonUsedData::updateOpGroupInfo(size_t grpIdx) { - std::queue tmpOpQueue(getFusionStrategy().getOpGroups()[grpIdx]); +void GroupOperationFusionImpl::updateOpGroupInfo(size_t grpIdx) { + std::queue tmpOpQueue( + getGroupOperationFusion().getOpGroups()[grpIdx]); // dummy init VectorType currentMaxRankType = getOperationMaxVectorType(tmpOpQueue.front()).value(); - getFusionStrategy().getGroupBiggestRankVectorType()[grpIdx] = + getGroupOperationFusion().getGroupBiggestRankVectorType()[grpIdx] = currentMaxRankType; while (!tmpOpQueue.empty()) { @@ -3313,13 +2684,14 @@ void CanonicalizerCommonUsedData::updateOpGroupInfo(size_t grpIdx) { tmpOpQueue.pop(); VectorType type = getOperationMaxVectorType(curOp).value(); if (type.getRank() > currentMaxRankType.getRank()) - getFusionStrategy().getGroupBiggestRankVectorType()[grpIdx] = type; + getGroupOperationFusion().getGroupBiggestRankVectorType()[grpIdx] = type; } } -void CanonicalizerCommonUsedData::updateOpOperandResultInGroups( +void GroupOperationFusionImpl::updateOpOperandResultInGroups( size_t opGid, Operation *op, const Value &init, const Value &result) { - std::queue 
tmpOpQueue(getFusionStrategy().getOpGroups()[opGid]); + std::queue tmpOpQueue( + getGroupOperationFusion().getOpGroups()[opGid]); std::queue newOpQueue; while (!tmpOpQueue.empty()) { auto curOp = tmpOpQueue.front(); @@ -3332,34 +2704,35 @@ void CanonicalizerCommonUsedData::updateOpOperandResultInGroups( if (!failed(getOperationVectorType(init.getDefiningOp()))) { newOpQueue.push(init.getDefiningOp()); - getFusionStrategy().getOpGroupIndexMap()[init.getDefiningOp()] = opGid; - getFusionStrategy().getOpAnchorPos()[init.getDefiningOp()] = - getFusionStrategy().getOpAnchorPos()[op]; + getGroupOperationFusion().getOpGroupIndexMap()[init.getDefiningOp()] = + opGid; + getGroupOperationFusion().getOpAnchorPos()[init.getDefiningOp()] = + getGroupOperationFusion().getOpAnchorPos()[op]; } newOpQueue.push(op); if (result && !failed(getOperationVectorType(result.getDefiningOp()))) { newOpQueue.push(result.getDefiningOp()); - getFusionStrategy().getOpGroupIndexMap()[result.getDefiningOp()] = opGid; - getFusionStrategy().getOpAnchorPos()[result.getDefiningOp()] = - getFusionStrategy().getOpGroupIndexMap()[op]; + getGroupOperationFusion().getOpGroupIndexMap()[result.getDefiningOp()] = + opGid; + getGroupOperationFusion().getOpAnchorPos()[result.getDefiningOp()] = + getGroupOperationFusion().getOpGroupIndexMap()[op]; } } - getFusionStrategy().getOpGroups()[opGid] = newOpQueue; + getGroupOperationFusion().getOpGroups()[opGid] = newOpQueue; } -void VectorFusionStrategy::run() { classifyOperations(); } - -void CanonicalizerCommonUsedData::generateEmptyTensorAndWrite( +void GroupOperationFusionImpl::generateEmptyTensorAndWrite( Operation *sourceOp, DenseMap> &srcOpCanoniclizedMap, size_t anchorPos, ReturnTypeKind retKind, DenseMap &visitedOperation) { DenseMap &opGroupIndexMap = - getFusionStrategy().getOpGroupIndexMap(); - SmallVector, 8> &groupOpInitArgs = getGroupOpInitArgs(); + getGroupOperationFusion().getOpGroupIndexMap(); + SmallVector, 8> &groupOpInitArgs = + getGroupOperationFusion().getGroupOpInitArgs(); SmallVector>, 8> - &groupOpResults = getGroupOpResults(); + &groupOpResults = getGroupOperationFusion().getGroupOpResults(); size_t sourceOpGid = opGroupIndexMap[sourceOp]; auto [tsr, writeOpresult] = @@ -3370,77 +2743,16 @@ void CanonicalizerCommonUsedData::generateEmptyTensorAndWrite( groupOpInitArgs[sourceOpGid].insert(tsr); groupOpResults[sourceOpGid].insert({writeOpresult, {retKind, anchorPos}}); // write opeartion anchor pos is same with current operation - getFusionStrategy().getOpAnchorPos()[writeOp] = + getGroupOperationFusion().getOpAnchorPos()[writeOp] = writeOp.getVectorType().getRank() - 1; - getOpPermuationMap()[writeOp] = writeOp.getPermutationMap(); -} - -template -Operation *CanonicalizerCommonUsedData::getNextTargetOperationInCurrentGroup( - Operation *curOp, const size_t grpIdx) { - std::queue tmpOpQueue(getFusionStrategy().getOpGroups()[grpIdx]); - if (isa(curOp)) - return curOp; - - while (!tmpOpQueue.empty()) { - auto frontOp = tmpOpQueue.front(); - if (isa(frontOp)) { - for (auto x : frontOp->getOperands()) - if (x.getDefiningOp() == curOp) - return frontOp; - } - tmpOpQueue.pop(); - } - return nullptr; + getGroupOperationFusion().getOpPermuationMap()[writeOp] = + writeOp.getPermutationMap(); } -void VectorOperationAnalyzer::analysisEmptyGroup() { - SmallVector, 8> &opGroups = - getFusionStrategy().getOpGroups(); - SmallVector>, 8> - &groupOpResults = getGroupOpResults(); - for (auto [idx, grp] : llvm::enumerate(opGroups)) { - if (grp.empty()) - continue; - if 
(groupOpResults[idx].empty()) - std::queue().swap(grp); - } -} - -void VectorOperationAnalyzer::analysisGroupMaxSteps() { - auto &opGroups = getFusionStrategy().getOpGroups(); - - for (auto [idx, grp] : llvm::enumerate(opGroups)) { - - uint32_t steps = std::numeric_limits::max(); - - llvm::SmallVector &grpSteps = - getFusionStrategy().getGroupMaxSteps(); - while (idx + 1 > grpSteps.size()) - grpSteps.emplace_back(steps); - - std::queue tmpQueue(grp); - auto calculateOpSteps = [&](Type type) { - auto opType = dyn_cast(type); - if (opType) - steps = std::min(steps, (uint32_t)getDataTypeValidSteps(opType)); - }; - while (!tmpQueue.empty()) { - auto op = tmpQueue.front(); - tmpQueue.pop(); - if (isa(op)) - calculateOpSteps(op->getOperandTypes()[0]); - - calculateOpSteps(getOperationVectorType(op).value()); - } - grpSteps[idx] = steps; - } -} - -void VectorOperationAnalyzer::specialOperationRectify( +void GroupOperationFusionImpl::specialOperationRectify( DenseMap &visitedOperation) { - auto &opGroups = getFusionStrategy().getOpGroups(); - IRRewriter rewriter(func); + auto &opGroups = getGroupOperationFusion().getOpGroups(); + IRRewriter rewriter(getGroupOperationFusion().getFunction()); for (auto [idx, grp] : llvm::enumerate(opGroups)) { std::queue tmpQueue(grp); @@ -3449,14 +2761,14 @@ void VectorOperationAnalyzer::specialOperationRectify( auto op = tmpQueue.front(); tmpQueue.pop(); // remain transfer read operation to do the broadcast fusion - if (isa(op) and not disableBroadcastOp) { + if (isa(op)) { auto srcOp = op->getOperand(0).getDefiningOp(); assert(isa(srcOp)); // only have write operation, otherwise the group size will bigger // than 1. Because the last operation is always a write operation in // each group - getFusionStrategy().getOpAnchorPos()[srcOp] = - getFusionStrategy().getOpAnchorPos()[op]; + getGroupOperationFusion().getOpAnchorPos()[srcOp] = + getGroupOperationFusion().getOpAnchorPos()[op]; rewriter.replaceOp(op, srcOp); continue; @@ -3464,22 +2776,22 @@ void VectorOperationAnalyzer::specialOperationRectify( // anchor of multidim reduction rectify if (isa(op)) { auto accSourceOp = op->getOperand(1).getDefiningOp(); - getFusionStrategy().getOpAnchorPos()[accSourceOp] = + getGroupOperationFusion().getOpAnchorPos()[accSourceOp] = getOperationVectorType(accSourceOp)->getRank() - 1; } newQueue.push(op); } - getFusionStrategy().getOpGroups()[idx] = newQueue; + getGroupOperationFusion().getOpGroups()[idx] = newQueue; } } -void VectorOperationAnalyzer::updateReturnResultKind(Operation *sourceOp, - size_t sourceOpGid, - ReturnTypeKind rtKind) { +void GroupOperationFusionImpl::updateReturnResultKind(Operation *sourceOp, + size_t sourceOpGid, + ReturnTypeKind rtKind) { SmallVector>, 8> - &groupOpResults = getGroupOpResults(); + &groupOpResults = getGroupOperationFusion().getGroupOpResults(); DenseMap &OpAnchorPos = - getFusionStrategy().getOpAnchorPos(); + getGroupOperationFusion().getOpAnchorPos(); Value sourceResult = sourceOp->getResults()[0]; if (srcOpCanoniclizedMap.contains(sourceOp)) @@ -3498,11 +2810,11 @@ void VectorOperationAnalyzer::updateReturnResultKind(Operation *sourceOp, std::make_pair(rtKind, srcOpAnchor); } -void VectorOperationAnalyzer::replaceConstantOpAsNewOp(Operation *op, - Operation *sourceOp, - size_t operandIdx) { +void GroupOperationFusionImpl::replaceConstantOpAsNewOp(Operation *op, + Operation *sourceOp, + size_t operandIdx) { DenseMap &opGroupIndexMap = - getFusionStrategy().getOpGroupIndexMap(); + getGroupOperationFusion().getOpGroupIndexMap(); if 
(!opGroupIndexMap.contains(op)) { return; } @@ -3527,10 +2839,12 @@ void VectorOperationAnalyzer::replaceConstantOpAsNewOp(Operation *op, auto constantOp = cast(sourceOp); IRRewriter rewriter(constantOp); size_t groupSteps = - getFusionStrategy().getGroupMaxSteps()[opGroupIndexMap[op]]; + getGroupOperationFusion().getGroupMaxSteps()[opGroupIndexMap[op]]; if (isa(constantOp.getValue())) { - VectorType newOperandType = getVectorzedType(op, groupSteps); + VectorType newOperandType = + getGroupOperationFusion().getTypeHelper().getVectorzedType(op, + groupSteps); auto valueType = cast(constantOp.getValue()); if (valueType.isSplat()) { FailureOr res = createArithSplatConstantOp( @@ -3550,17 +2864,19 @@ void VectorOperationAnalyzer::replaceConstantOpAsNewOp(Operation *op, } } -void VectorOperationAnalyzer::makeSourceOpWriteResultToTensor( +void GroupOperationFusionImpl::makeSourceOpWriteResultToTensor( Operation *sourceOp, size_t sourceOpGid, ReturnTypeKind rtKind) { DenseMap &OpAnchorPos = - getFusionStrategy().getOpAnchorPos(); - SmallVector, 8> &groupOpInitArgs = getGroupOpInitArgs(); + getGroupOperationFusion().getOpAnchorPos(); + SmallVector, 8> &groupOpInitArgs = + getGroupOperationFusion().getGroupOpInitArgs(); if (!srcOpCanoniclizedMap.contains(sourceOp)) { // get write operation if (Operation *writeOp = - getNextTargetOperationInCurrentGroup( - sourceOp, sourceOpGid)) { + getGroupOperationFusion() + .getNextTargetOperationInCurrentGroup( + sourceOp, sourceOpGid)) { auto writeOpresult = writeOp->getResults()[0]; auto writeTensor = writeOp->getOperands()[1]; // find original tensor.empty operation @@ -3580,15 +2896,16 @@ void VectorOperationAnalyzer::makeSourceOpWriteResultToTensor( sourceOpGid, rtKind); } -void VectorOperationAnalyzer::groupOperationNeedReturnResult( +void GroupOperationFusionImpl::GroupOperationReturnResultProcess( size_t sourceOpGid, Operation *sourceOp, Operation *op, size_t operandIdx, bool inSameGroupNeedReturn) { ReturnTypeKind rtKind = inSameGroupNeedReturn ? 
ReturnTypeKind::RT_InGroup : ReturnTypeKind::RT_OutGroup; - SmallVector, 8> &groupOpInitArgs = getGroupOpInitArgs(); + SmallVector, 8> &groupOpInitArgs = + getGroupOperationFusion().getGroupOpInitArgs(); DenseMap &opGroupIndexMap = - getFusionStrategy().getOpGroupIndexMap(); + getGroupOperationFusion().getOpGroupIndexMap(); // update init iterargs auto dstRet = getOperationOperateTensor(sourceOp); // need to generate tensor.emtpy and vector.transfer_write, write @@ -3620,16 +2937,17 @@ void VectorOperationAnalyzer::groupOperationNeedReturnResult( updateReturnResultKind(sourceOp, sourceOpGid, rtKind); } -void VectorOperationAnalyzer::analysisGroupOperaion() { +void GroupOperationFusionImpl::canonicalizeEachOperationGroup() { // record the operation which has been moved DenseSet movedOperationSet; // record the operation's visited order, inorder to ensure set // correct operand size_t opCounter = 0; DenseMap &opGroupIndexMap = - getFusionStrategy().getOpGroupIndexMap(); + getGroupOperationFusion().getOpGroupIndexMap(); DenseMap &OpAnchorPos = - getFusionStrategy().getOpAnchorPos(); + getGroupOperationFusion().getOpAnchorPos(); + func::FuncOp func = getGroupOperationFusion().getFunction(); IRRewriter rewriter(func); analysisGroupMaxSteps(); @@ -3650,8 +2968,8 @@ void VectorOperationAnalyzer::analysisGroupOperaion() { !notInSameGroup and OpAnchorPos[sourceOp] > OpAnchorPos[op]; if (notInSameGroup or outOfGroup or inSameGroupNeedReturn) - groupOperationNeedReturnResult(sourceOpGid, sourceOp, op, idx, - inSameGroupNeedReturn); + GroupOperationReturnResultProcess(sourceOpGid, sourceOp, op, idx, + inSameGroupNeedReturn); continue; } @@ -3667,10 +2985,11 @@ void VectorOperationAnalyzer::analysisGroupOperaion() { void ForLoopGenerator::rectifyGroupOperands(size_t currentGroupId, Value originalResult, Value forResult) { - size_t totalGroupSize = getFusionStrategy().getOpGroups().size(); + size_t totalGroupSize = getVectorBasedFusion().getOpGroups().size(); size_t startGroup = currentGroupId; while (startGroup < totalGroupSize) { - SetVector &operandVector = getGroupOpInitArgs()[startGroup++]; + SetVector &operandVector = + getVectorBasedFusion().getGroupOpInitArgs()[startGroup++]; if (not operandVector.contains(originalResult)) continue; SetVector replacedVector; @@ -3682,7 +3001,8 @@ void ForLoopGenerator::rectifyGroupOperands(size_t currentGroupId, } replacedVector.insert(v); } - getGroupOpInitArgs()[startGroup - 1] = replacedVector; + getVectorBasedFusion().getGroupOpInitArgs()[startGroup - 1] = + replacedVector; } } @@ -3703,8 +3023,7 @@ mlir::FailureOr ForLoopGenerator::generateVectorizedForLoop( return forOp; } -bool CanonicalizerCommonUsedData::isGroupHasSpecialOperation( - const size_t grpIdx) { +bool LoopGeneratorImpl::isGroupHasSpecialOperation(const size_t grpIdx) { auto &rdCanonicalizer = getMultiRdCanonicalizers()[grpIdx]; auto &bcCanonicalizer = getBroadcastCanonicalizers()[grpIdx]; auto &tpCanonicalizer = getTransposeCanonicalizers()[grpIdx]; @@ -3715,8 +3034,8 @@ bool CanonicalizerCommonUsedData::isGroupHasSpecialOperation( !shapeCastCanonicalizer.getCandidateOps().empty(); } -void ForLoopGenerator::generateGroupOpVectorizedIR(const int idx) { - auto &grp = getFusionStrategy().getOpGroups()[idx]; +void LoopGeneratorImpl::generateGroupOpVectorizedIR(const int idx) { + auto &grp = getVectorBasedFusion().getOpGroups()[idx]; if (grp.empty()) { LDBG("Current operation Group is empty."); return; @@ -3727,7 +3046,7 @@ void ForLoopGenerator::generateGroupOpVectorizedIR(const int idx) { } 
VectorType groupType = - getFusionStrategy().getGroupBiggestRankVectorType()[idx]; + getVectorBasedFusion().getGroupBiggestRankVectorType()[idx]; IRRewriter rewriter(grp.back()); rewriter.setInsertionPointAfter(grp.back()); // 1. Rewrite operation as vectorized form @@ -3735,9 +3054,9 @@ void ForLoopGenerator::generateGroupOpVectorizedIR(const int idx) { // rewriteOperationAsVectorize(rewriter, idx); auto forOp = generateVectorizedForLoop(idx, rewriter, groupType); // special operation do not need to change anything - if (failed(forOp)) { + if (failed(forOp)) return; - } + moveLoopInvariantCode(forOp.value()); } @@ -3911,10 +3230,10 @@ struct CPUPhysicalRegisterPass HardWareInfo hwInfo; CPUTargetDescriptionAnalysis sysDesc = getAnalysis(); - hwInfo.favx512f = sysDesc.getMaxVectorWidth() == 512; + hwInfo.favx512f = sysDesc.getMaxVectorWidth() >= 512; hwInfo.favx2 = sysDesc.getMaxVectorWidth() >= 256; - CanonicalizerVectorOperation canonicalizer( - func, CanonicalizerKind::OperationsGroup, hwInfo); + VectorOperationCanonicalizer canonicalizer( + func, hwInfo, CanonicalizerKind::GroupOperations); canonicalizer.run(); candidateFunc = isReadOrWriteOperation; @@ -3930,7 +3249,6 @@ struct CPUPhysicalRegisterPass (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); } }; -} // namespace } // namespace gc } // namespace mlir \ No newline at end of file diff --git a/lib/gc/Transforms/Pipeline.cpp b/lib/gc/Transforms/Pipeline.cpp index b4a46d231..0c9bb6322 100644 --- a/lib/gc/Transforms/Pipeline.cpp +++ b/lib/gc/Transforms/Pipeline.cpp @@ -77,6 +77,7 @@ void populateTensorPasses(mlir::OpPassManager &pm) { // scf + arith + math + vector + tensor + linalg.brgemm void populateVectorPasses(mlir::OpPassManager &pm) { + pm.addNestedPass(createLowerToTileVector()); // Do promotion for math / arith ops pm.addNestedPass(math::createMathLegalizeToF32()); // sourceTypeStrs can be extended @@ -88,8 +89,7 @@ void populateVectorPasses(mlir::OpPassManager &pm) { arith::createArithEmulateUnsupportedFloats(options)); // Bf16 cast elimilation pass pm.addNestedPass(mlir::createCanonicalizerPass()); - // oneDNN graph spec - pm.addNestedPass(arith::createArithExpandOpsPass()); + pm.addNestedPass(createCPUPhysicalRegisterPass()); // todo: lower to physical vector pass, device dependent pass populateCleanUpPasses(pm); } diff --git a/lib/gc/Transforms/TilingVector.h b/lib/gc/Transforms/TilingVector.hpp similarity index 71% rename from lib/gc/Transforms/TilingVector.h rename to lib/gc/Transforms/TilingVector.hpp index 9ad3d074a..0545325e6 100644 --- a/lib/gc/Transforms/TilingVector.h +++ b/lib/gc/Transforms/TilingVector.hpp @@ -8,10 +8,8 @@ #ifndef GC_PASSES_TILINGVECTOR_H #define GC_PASSES_TILINGVECTOR_H +#include "gc/Analysis//VectorBasedFusionAnalysis.h" #include "gc/Analysis/TargetDescriptionAnalysis.h" -#include "gc/Dialect/Linalgx/LinalgxOps.h" -#include "gc/Dialect/Microkernel/MicrokernelOps.h" -#include "gc/Transforms/Passes.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -19,23 +17,17 @@ #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/Transforms/Transforms.h" -#include "mlir/Dialect/Traits.h" -#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" #include "mlir/Dialect/Vector/Transforms/Passes.h" #include 
"mlir/ExecutionEngine/Float16bits.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Visitors.h" #include "mlir/Transforms/CSE.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" -#include + namespace mlir { namespace gc { -namespace { //===----------------------------------------------------------------------===// // helper function @@ -47,11 +39,22 @@ Value makeIndexArithConstantOp(OpBuilder &opBuilder, Location &loc, int64_t x); /// get operation read or write tensor mlir::FailureOr getOperationOperateTensor(Operation *op); -/// record hardware information -struct HardWareInfo { - bool favx512f = true; - bool favx2 = true; -}; +/// set correct operand for the operation +void setOperationCorrectOperand( + Operation *op, ValueRange iterArgs, DenseMap &operandIdxMap, + DenseMap &originalOperandLoopArgsMap, + ArrayRef inductionVars, + DenseMap &opPermuationMap); + +/// get fusion kind +/// Has two kind: +/// 1. OperationGroup: +/// The operation is converted into physical registers through our fusion +/// strategy. +/// 2. Operations:(TODO:) +/// The user ensures that there is no data dependency between operations, +/// and we directly convert the operations into physical register sizes. +enum CanonicalizerKind { GroupOperations, Operations }; /// To avoid too many parameters in function when generate for loop struct GenerateLoopHelper { @@ -132,139 +135,12 @@ struct GenerateLoopHelper { DenseMap &argsOriginalMap); }; -/// set correct operand for the operation -void setOperationCorrectOperand( - Operation *op, ValueRange iterArgs, DenseMap &operandIdxMap, - DenseMap &originalOperandLoopArgsMap, - ArrayRef inductionVars, - DenseMap &opPermuationMap); - //===----------------------------------------------------------------------===// // vectorize operation class //===----------------------------------------------------------------------===// -/// Vector type conversion helper class -class TypeHelper { -private: - HardWareInfo HWInfo; - -public: - /// use \param info to set hardware information - void setHardWareInfo(HardWareInfo &info) { HWInfo = info; } - /// get vector \param type max loop step according to hardware information - int getDataTypeValidSteps(VectorType type); - /// get vector \param type an even for loop step - int generateValidSteps(int steps, VectorType type); - /// get vector \param type max simd length according to hardware information - int getDataTypeMAXSIMDLength(VectorType type); - /// get operation's vector type - VectorType getVectorzedType(Operation *op, uint32_t loopStep = 0); -}; - -/// Operation fusion strategy class. -/// 1. Classify operaions: -/// classify the operations into : -/// a. reorder, transpose. Reorder(or transpose) dim may bring data -/// dependency. -/// b. elemenwise. Those operations can be fused into a common for loop. -/// c. broadcast. Need to analysis broadcast dim and the data -/// dependency. -/// d. reduction. Need to analysis broadcast dim and the -/// data dependency. -/// Same group operations have no data dependencies. They can be fused into a -/// common for loop body. - -/// Using queue to store the operation order. In order to ensure that -/// subsequent moves to the operation will not cause semantic changes. 
-class VectorFusionStrategy : public TypeHelper { -private: - func::FuncOp func; - SmallVector, 8> opGroups; - SmallVector groupMaxSteps; - /// vector type which has bigest rank in current operation group - llvm::SmallDenseMap groupBigestRankVectorType; - /// query current operation in which group, return group index - DenseMap opGroupIndexMap; - /// can fused into prev operation which axis position - DenseMap opAnchorPos; - /// record some operations which not need to No need to judge whether can be - /// fused - std::queue notNeedToJudgeOps; - -public: - VectorFusionStrategy() = default; - VectorFusionStrategy(func::FuncOp &func) : func(func) {} - VectorFusionStrategy(func::FuncOp &func, TypeHelper &typeHelper) - : TypeHelper(typeHelper), func(func) {} - - VectorFusionStrategy(VectorFusionStrategy &strategy) - : func(strategy.func), opGroups(strategy.opGroups), - groupMaxSteps(strategy.groupMaxSteps), - opGroupIndexMap(strategy.opGroupIndexMap), - opAnchorPos(strategy.opAnchorPos) {}; - - VectorFusionStrategy(VectorFusionStrategy &&strategy) - : func(std::move(strategy.func)), opGroups(std::move(strategy.opGroups)), - groupMaxSteps(std::move(strategy.groupMaxSteps)), - opGroupIndexMap(std::move(strategy.opGroupIndexMap)), - opAnchorPos(std::move(strategy.opAnchorPos)) {}; - - VectorFusionStrategy &operator=(VectorFusionStrategy &&) = default; - - /// Get the map which contains each group vector type which has biggest rank. - llvm::SmallDenseMap & - getGroupBiggestRankVectorType() noexcept { - return groupBigestRankVectorType; - }; - /// Get the operation group obtained by fusion strategy analysis - SmallVector, 8> &getOpGroups() noexcept { - return opGroups; - } - /// Get the operation belong to which group index map - DenseMap &getOpGroupIndexMap() noexcept { - return opGroupIndexMap; - } - /// Get the map contains max steps of each group - llvm::SmallVector &getGroupMaxSteps() noexcept { - return groupMaxSteps; - } - /// Get the map contains anchor position of each operation - llvm::DenseMap &getOpAnchorPos() noexcept { - return opAnchorPos; - } - /// Get current function IR - func::FuncOp &getFunc() { return func; } - /// Do fusion strategy - void classifyOperations(); - - /// Whether two operations have compatible vector shapes - bool isCompatibleVectorType(Operation *op1, Operation *op2); - - /// update bigest vector type for last operation group - void updateGroupBigestVectorType(VectorType vectorType); - - /// Check whether the operation can fuse with previous operation - bool isNeedNewGroup(Operation *op); - - /// Add Operation \p op into current last group or a new Group - /// \p op must has valid value, can't be nullptr - void addOperationToGroup(Operation *op); - - /// run the vector-based fusion strategy - void run(); -}; - -/// Has two kind: -/// 1. OperationGroup: -/// The operation is converted into physical registers through our fusion -/// strategy. -/// 2. Operations:(TODO:) -/// The user ensures that there is no data dependency between operations, -/// and we directly convert the operations into physical register sizes. 
-enum CanonicalizerKind { OperationsGroup, Operations }; - /// base class of special operation -template class SpecialOperationCanonicalizer : virtual TypeHelper { +template class SpecialOperationCanonicalizer { private: /// store current special operation SmallVector candidateRdOps; @@ -396,7 +272,7 @@ class BroadcastCanonicalizer private: public: BroadcastCanonicalizer( - const llvm::SmallVector &candidateBcOps, + const SmallVector &candidateBcOps, size_t steps = 1) : SpecialOperationCanonicalizer( candidateBcOps, SpecialOperationKind::OP_Broadcast, steps) {}; @@ -457,172 +333,22 @@ class ShapeCastCanonicalizer bool isReadWriteOnLastDim(); }; -/// operation return kind, which is used to determine whether the operation need -/// to return it's result in current for loop -enum class ReturnTypeKind { - RT_Both, - RT_OutGroup, - RT_InGroup, -}; - -class CanonicalizerCommonUsedData : public TypeHelper { -private: - VectorFusionStrategy fusionStrategy; - -private: - /// analysis the operation's operands and results - SmallVector>, 8> - groupOpResults; - /// store loop iteration args for each of operation group - SmallVector, 8> groupOpInitArgs; - - // store read and write operations permutation maps in order to convenient - // to replace loop induction var - DenseMap opPermuationMap; - SmallVector multiRdCanonicalizers; - SmallVector broadcastCanonicalizers; - SmallVector transposeCanonicalizers; - SmallVector shapeCastCanonicalizers; - -public: - CanonicalizerCommonUsedData() = default; - CanonicalizerCommonUsedData(VectorFusionStrategy &fusionStrategy) - : fusionStrategy(fusionStrategy) {}; - - CanonicalizerCommonUsedData( - VectorFusionStrategy &fusionStrategy, - SmallVector>, 8> - &groupOpResults, - SmallVector, 8> &groupOpInitArgs, - DenseMap &opPermuationMap) - : fusionStrategy(fusionStrategy), groupOpResults(groupOpResults), - groupOpInitArgs(groupOpInitArgs), opPermuationMap(opPermuationMap) {} - virtual ~CanonicalizerCommonUsedData() noexcept {}; - - /// Set fusion strategy - void setFuseStrategy(VectorFusionStrategy &&strategy) { - fusionStrategy = std::move(strategy); - llvm::SmallVector, 8> &opGroups = - fusionStrategy.getOpGroups(); - // init operations results and initialization args - if (opGroups.size() != groupOpResults.size() || - opGroups.size() != groupOpInitArgs.size()) { - groupOpResults.clear(); - groupOpInitArgs.clear(); - for (size_t i = 0; i < opGroups.size(); i++) { - groupOpResults.emplace_back( - llvm::MapVector>()); - groupOpInitArgs.emplace_back(SetVector()); - } - } - } - - void setGroupOpResults( - const SmallVector< - llvm::MapVector>, 8> - &results) { - groupOpResults = std::move(results); - } - - void setGroupOpIterArgs( - const SmallVector, 8> &initArgs) noexcept { - groupOpInitArgs = std::move(initArgs); - } - - void setPermutationMap(const DenseMap &map) noexcept { - opPermuationMap = std::move(map); - } - - // get methods - VectorFusionStrategy &getFusionStrategy() noexcept { return fusionStrategy; } - - SmallVector>, 8> & - getGroupOpResults() noexcept { - return groupOpResults; - } - - SmallVector, 8> &getGroupOpInitArgs() noexcept { - return groupOpInitArgs; - } - - DenseMap &getOpPermuationMap() noexcept { - return opPermuationMap; - } - - SmallVector & - getMultiRdCanonicalizers() noexcept { - return multiRdCanonicalizers; - } - - SmallVector & - getBroadcastCanonicalizers() noexcept { - return broadcastCanonicalizers; - } - - SmallVector & - getTransposeCanonicalizers() noexcept { - return transposeCanonicalizers; - } - - SmallVector & - 
getShapeCastCanonicalizers() noexcept { - return shapeCastCanonicalizers; - } - - /// whether \param grpIdx operation group has special operation - bool isGroupHasSpecialOperation(const size_t grpIdx); - - /// make emtpy tensor and write the operation result to the tensor - void generateEmptyTensorAndWrite( - Operation *sourceOp, - llvm::DenseMap> - &srcOpCanoniclizedMap, - size_t anchorPos, ReturnTypeKind retKind, - DenseMap &visitedOperation); - - /// update \param opGid operation group - void updateOpOperandResultInGroups(size_t opGid, Operation *op, - const Value &init = Value(), - const Value &result = Value()); - /// replace \param op in \param grpIdx operation group with \param replacedOp - void removeOpInCurrentGroups(size_t grpIdx, Operation *op, - Operation *replacedOp); - /// update operation in grpIdx group related information - void updateOpGroupInfo(size_t grpIdx); - - /// make a transfer_read operation and read the producer operation result - Value - canonicalizeCurrentOperation(Operation *op, const Value &transferReadOperand, - size_t operandIdx, - vector::TransferReadOp *srcReadOp = nullptr); - - /// make a transfer_read operation - Operation * - createTransferReadOpBefore(Operation *op, const Value &operand, - vector::TransferReadOp *srcReadOp = nullptr); - /// get next operation in current operation group - template - Operation *getNextTargetOperationInCurrentGroup(Operation *curOp, - const size_t grpIdx); -}; - /// generate for loop for each operation. -class ForLoopGenerator : virtual public CanonicalizerCommonUsedData { +class ForLoopGenerator { private: - /// currrent function IR - func::FuncOp func; + GroupOperationFusion vectorBasedFusion; public: - ForLoopGenerator() = default; - ForLoopGenerator(func::FuncOp &func) : func(func) {} + ForLoopGenerator(GroupOperationFusion &fusion) : vectorBasedFusion(fusion) {} virtual ~ForLoopGenerator() noexcept {} - void setGeneratorFunc(func::FuncOp &func) noexcept { this->func = func; } + void setVectorBaseFusion(GroupOperationFusion &vectorBasedFusion) { + this->vectorBasedFusion = vectorBasedFusion; + }; + /// clear current group operation void clearCurrentOperationGroup(size_t grpIdx); - /// vectorize operations in current operation group - void generateGroupOpVectorizedIR(const int idx); /// prepare for loop iteration args void prepareForLoopArgs(const size_t grpIdx, GenerateLoopHelper &loopHelper); @@ -639,6 +365,11 @@ class ForLoopGenerator : virtual public CanonicalizerCommonUsedData { DenseMap> &indiceLoopMap, const size_t groupId, Operation *op, const DenseMap &setIdxMap = DenseMap({})); + + // get methods + GroupOperationFusion &getVectorBasedFusion() noexcept { + return vectorBasedFusion; + } /// rewrite operation as vectorize IR in current operation group void rewriteOperationAsVectorize(OpBuilder &rewriter, size_t groupId, @@ -654,7 +385,7 @@ class ForLoopGenerator : virtual public CanonicalizerCommonUsedData { scf::ForOp constructNestedForOp(const size_t groupIdx, OpBuilder &b, const Location &loc, ArrayRef dims, GenerateLoopHelper &loopHelper); - + /// move operations in \param queue to current loop anchor void moveOperationsToCurrentForBody(const OpBuilder &b, std::queue &queue, GenerateLoopHelper &loopHelperParam); @@ -698,13 +429,80 @@ class ForLoopGenerator : virtual public CanonicalizerCommonUsedData { void replaceOperationsWithForLoopResult( IRRewriter &rewrite, const std::queue &movingOperations, GenerateLoopHelper &loopHelperParam); + + /// rectify indice for transfer_write operation + /// e.g.: 
vector.transfer_write"(%16, %9, %c0, %c0), the first %c0 should use + /// original indice not create by us + void rectifyWriteOperationIndice(vector::TransferWriteOp *originalWriteOp, + SmallVectorImpl &writeVars); + /// rectify indice for transfer_read operation, like broadcast operation + /// fusion by transfer_read , but the transfer_read operation is in innermost + /// for loop body, we must set correct for loop var. e.g.: + /// vector.transfer_read"(%16, %9, %c0), the first %c0 should use correct for + /// innermost loop iter vars + void rectifyReadOperationIndice(vector::TransferReadOp *originalReadOp, + VectorType loopType, + ArrayRef inductionVars, + SmallVectorImpl &readVars); + + /// rectify each group operand use for loop result + void rectifyGroupOperands(size_t currentGroupId, Value originalResult, + Value forResult); +}; + +class LoopGeneratorImpl : public ForLoopGenerator { + +private: + SmallVector multiRdCanonicalizers; + SmallVector broadcastCanonicalizers; + SmallVector transposeCanonicalizers; + SmallVector shapeCastCanonicalizers; + +public: + LoopGeneratorImpl(GroupOperationFusion &fusion) : ForLoopGenerator(fusion) {}; + + virtual ~LoopGeneratorImpl() noexcept {}; + + SmallVector & + getMultiRdCanonicalizers() noexcept { + return multiRdCanonicalizers; + } + + SmallVector & + getBroadcastCanonicalizers() noexcept { + return broadcastCanonicalizers; + } + + SmallVector & + getTransposeCanonicalizers() noexcept { + return transposeCanonicalizers; + } + + SmallVector & + getShapeCastCanonicalizers() noexcept { + return shapeCastCanonicalizers; + } + /// clear special operation canonicalizer container + void clearSpecialOperationCanonicalizers(); + + /// add a dummy special canonicalizer + void dummyInitSpecialOperation(size_t steps); + + /// initialize all the speical operation canonicalizer + void initSpeicalOperationCanonicalizers(); + + /// generate for loop for current special operation use \param generateFunc + template + void processSpecialOperation( + T &canonicalizers, const std::function &generateFunc); + // Canonicalize special operation + void canonicalizeSpecialOperation(); + + /// whether \param grpIdx operation group has special operation + bool isGroupHasSpecialOperation(const size_t grpIdx); + // multireduction forloop methods scf::ForOp generateMultiReductionForLoop(const size_t grpIdx); - /// Rearrange the current opIR to facilitate the generation of the correct - /// reduction IR - void rearrageMultiReductionIR( - const size_t grpIdx, - DenseMap> &indiceLoopMap); /// reduction operation reduction axis for loop scf::ForOp reductionAxisGenerateForLoop(OpBuilder &opBuilder, @@ -721,6 +519,12 @@ class ForLoopGenerator : virtual public CanonicalizerCommonUsedData { DenseMap &nextAnchorArgsIdxMap, SmallVector &nextAnchorArgs); + /// Rearrange the current opIR to facilitate the generation of the correct + /// reduction IR + void rearrageMultiReductionIR( + const size_t grpIdx, + DenseMap> &indiceLoopMap); + /// generate for loop for transpose operation scf::ForOp generateTransposeForLoop(const size_t grpIdx); /// shuffle instruction optimize for transpose operation @@ -741,46 +545,29 @@ class ForLoopGenerator : virtual public CanonicalizerCommonUsedData { OpBuilder &opBuilder, const size_t grpIdx, const size_t forDimIdx, const size_t steps, const Location &loc, SmallVector &inductionVars, ValueRange iterArgs); - /// rectify indice for transfer_write operation - /// e.g.: vector.transfer_write"(%16, %9, %c0, %c0), the first %c0 should use - /// original 
indice not create by us - void rectifyWriteOperationIndice(vector::TransferWriteOp *originalWriteOp, - SmallVectorImpl &writeVars); - /// rectify indice for transfer_read operation, like broadcast operation - /// fusion by transfer_read , but the transfer_read operation is in innermost - /// for loop body, we must set correct for loop var. e.g.: - /// vector.transfer_read"(%16, %9, %c0), the first %c0 should use correct for - /// innermost loop iter vars - void rectifyReadOperationIndice(vector::TransferReadOp *originalReadOp, - VectorType loopType, - ArrayRef inductionVars, - SmallVectorImpl &readVars); - /// rectify each group operand use for loop result - void rectifyGroupOperands(size_t currentGroupId, Value originalResult, - Value forResult); + /// vectorize operations in current operation group + void generateGroupOpVectorizedIR(const int idx); }; -class VectorOperationAnalyzer : virtual public CanonicalizerCommonUsedData { +/// group operation fusion implementation class +class GroupOperationFusionImpl : public GroupOperationAnalysis { private: - func::FuncOp func; + /// In which tensor is the result of the source operation stored, and the + /// result of transfer_write. DenseMap> srcOpCanoniclizedMap; + /// have visited operations DenseMap visitedOperation; public: - virtual ~VectorOperationAnalyzer() = default; - VectorOperationAnalyzer() = default; - VectorOperationAnalyzer(func::FuncOp &func) : func(func) {} - - void setAnalysisFunc(func::FuncOp &func) { this->func = func; } - /// remove the useless operation, due to it result is not require by other - // operation - void analysisEmptyGroup(); - /// get each operation in each group maximum support vectorization length - void analysisGroupMaxSteps(); - /// analysis operation result of current group whether needed by other - /// operation - void analysisGroupOperaion(); + virtual ~GroupOperationFusionImpl() = default; + GroupOperationFusionImpl(func::FuncOp &func, HardWareInfo &info) + : GroupOperationAnalysis(func, info) {} + + /// Generate emtpy tensor and write operations for operations that need to + /// return their results, and generate read operations for operations that + /// need to read parameters from the block. 
+ void canonicalizeEachOperationGroup(); void specialOperationRectify(DenseMap &visitedOperation); /// update operation result kind @@ -789,63 +576,63 @@ class VectorOperationAnalyzer : virtual public CanonicalizerCommonUsedData { /// process the operation which need to return result /// \param *op current operation - void groupOperationNeedReturnResult(size_t sourceOpGid, Operation *sourceOp, - Operation *op, size_t operandIdx, - bool inSameGroupNeedReturn); + void GroupOperationReturnResultProcess(size_t sourceOpGid, + Operation *sourceOp, Operation *op, + size_t operandIdx, + bool inSameGroupNeedReturn); /// source operation write it's result to a tensor void makeSourceOpWriteResultToTensor(Operation *sourceOp, size_t sourceOpGid, ReturnTypeKind rtKind); /// analysis constant operation and replace it with a new constant operation void replaceConstantOpAsNewOp(Operation *op, Operation *sourceOp, size_t operandIdx); + /// replace \param op in \param grpIdx operation group with \param replacedOp + void removeOpInCurrentGroups(size_t grpIdx, Operation *op, + Operation *replacedOp); + /// update operation in grpIdx group related information + void updateOpGroupInfo(size_t grpIdx); + /// make a transfer_read operation and read the producer operation result + Value + canonicalizeCurrentOperation(Operation *op, const Value &transferReadOperand, + size_t operandIdx, + vector::TransferReadOp *srcReadOp = nullptr); + /// update \param opGid operation group + void updateOpOperandResultInGroups(size_t opGid, Operation *op, + const Value &init = Value(), + const Value &result = Value()); + + /// make emtpy tensor and write the operation result to the tensor + void generateEmptyTensorAndWrite( + Operation *sourceOp, + llvm::DenseMap> + &srcOpCanoniclizedMap, + size_t anchorPos, ReturnTypeKind retKind, + DenseMap &visitedOperation); + + /// make a transfer_read operation + Operation * + createTransferReadOpBefore(Operation *op, const Value &operand, + vector::TransferReadOp *srcReadOp = nullptr); }; /// Vectorize vector operation with target machines max simd length. 
-class CanonicalizerVectorOperation : virtual public ForLoopGenerator, - VectorOperationAnalyzer { +class VectorOperationCanonicalizer { private: - /// current function IR + GroupOperationFusionImpl fusion; + LoopGeneratorImpl loopGenerator; + CanonicalizerKind kind; func::FuncOp func; - /// rewriter of func operation IRRewriter rewriter; - CanonicalizerKind kind; public: - CanonicalizerVectorOperation( - func::FuncOp func, - CanonicalizerKind kind = CanonicalizerKind::OperationsGroup, - HardWareInfo hwInfo = {}) - : func(func), rewriter(func), kind(kind) { - setAnalysisFunc(func); - setGeneratorFunc(func); - setHardWareInfo(hwInfo); - // vector operation fusion - if (kind == CanonicalizerKind::OperationsGroup) { - VectorFusionStrategy fusionStrategy(func); - fusionStrategy.run(); - setFuseStrategy(std::move(fusionStrategy)); - } - } - virtual ~CanonicalizerVectorOperation() = default; - - // get functions - func::FuncOp &getFunc() noexcept { return func; }; - IRRewriter &getIRWewriter() noexcept { return rewriter; } - /// generate for loop for current special operation use \param generateFunc - template - void processSpecialOperation( - T &canonicalizers, const std::function &generateFunc); - // Canonicalize special operation - void canonicalizeSpecialOperation(); - /// clear special operation canonicalizer container - void clearSpecialOperationCanonicalizers(); - /// add a dummy special canonicalizer - void dummyInitSpecialOperation(size_t steps); - /// initialize all the speical operation canonicalizer - void initSpeicalOperationCanonicalizers(); + VectorOperationCanonicalizer( + func::FuncOp &func, HardWareInfo &info, + CanonicalizerKind kind = CanonicalizerKind::GroupOperations) + : fusion(func, info), loopGenerator(fusion.getGroupOperationFusion()), + kind(kind), rewriter(func) {} + virtual ~VectorOperationCanonicalizer() = default; /// run the vector canonicalizer for the IR void run(); }; -} // namespace } // namespace gc } // namespace mlir #endif \ No newline at end of file diff --git a/test/mlir/test/gc/Transforms/cpu-vetor-distribution.mlir b/test/mlir/test/gc/Transforms/cpu-vetor-distribution.mlir index ccf2c9676..fd380ad24 100644 --- a/test/mlir/test/gc/Transforms/cpu-vetor-distribution.mlir +++ b/test/mlir/test/gc/Transforms/cpu-vetor-distribution.mlir @@ -599,7 +599,49 @@ func.func @reduce_fuse_test12(%input: tensor<16x32x64xf32>, func.return %1 : tensor<16x32xf32> } -// CHECK-LABEL: func @add_small_tensor_test13 +// CHECK-LABEL: func @reduce_fuse_test13 +// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32> +// CHECK: %[[C64:.*]] = arith.constant 64 : index +// CHECK: %[[C32:.*]] = arith.constant 32 : index +// CHECK: %[[C16:.*]] = arith.constant 16 : index +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[CST_0:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[EMPTY0:.*]] = tensor.empty() : tensor<16x32x64xf32> +// CHECK: scf.for %[[arg2:.*]] = %[[C0]] to %[[C16]] step %[[C1]] iter_args(%[[arg3:.*]] = %[[EMPTY0]]) -> (tensor<16x32x64xf32>) +// CHECK: scf.for %[[arg4:.*]] = %[[C0]] to %[[C32]] step %[[C1]] iter_args(%[[arg5:.*]] = %[[arg3]]) -> (tensor<16x32x64xf32>) +// CHECK: scf.for %[[arg6:.*]] = %[[C0]] to %[[C64]] step %[[C16]] iter_args(%[[arg7:.*]] = %[[arg5]]) -> (tensor<16x32x64xf32>) +// CHECK: %[[READ0:.*]] = vector.transfer_read %{{.*}}[%[[arg2]], %[[arg4]], %[[arg6]]], %[[CST_0]] {in_bounds = [true]} : tensor<16x32x64xf32>, vector<16xf32> +// CHECK: %[[ADD0:.*]] = 
arith.addf %[[READ0]], %[[READ0]] : vector<16xf32> +// CHECK: %[[WRITE0:.*]] = vector.transfer_write %[[ADD0]], %[[arg7]][%[[arg2]], %[[arg4]], %[[arg6]]] {in_bounds = [true]} : vector<16xf32>, tensor<16x32x64xf32> +// CHECK: scf.for %[[arg2:.*]] = %[[C0]] to %[[C16]] step %[[C16]] iter_args(%[[arg3:.*]] = %[[arg1]]) -> (tensor<16xf32>) +// CHECK: %[[READ1:.*]] = vector.transfer_read %[[arg3]][%[[arg2]]], %[[CST_0]] {in_bounds = [true]} : tensor<16xf32>, vector<16xf32> +// CHECK: scf.for %[[arg4:.*]] = %[[C0]] to %[[C16]] step %[[C1]] iter_args(%[[arg5:.*]] = %[[READ1]]) -> (vector<16xf32>) +// CHECK: scf.for %[[arg6:.*]] = %[[C0]] to %[[C32]] step %[[C1]] iter_args(%[[arg7:.*]] = %[[CST]]) -> (vector<16xf32>) +// CHECK: scf.for %[[arg8:.*]] = %[[C0]] to %[[C64]] step %[[C16]] iter_args(%[[arg9:.*]] = %[[arg7]]) -> (vector<16xf32>) +// CHECK: %[[READ2:.*]] = vector.transfer_read {{.*}}[%[[arg2]], %[[arg6]], %[[arg8]]], {{.*}} {in_bounds = [true]} : tensor<16x32x64xf32>, vector<16xf32> +// CHECK: %[[ADD0:.*]] = arith.addf %[[READ2]], %[[arg9]] : vector<16xf32> +// CHECK: %[[REDUCTION:.*]] = vector.reduction , {{.*}} : vector<16xf32> into f32 +// CHECK: %[[INSERT:.*]] = vector.insert %[[REDUCTION]], %[[arg5]] [%[[arg4]]] : f32 into vector<16xf32> +// CHECK: %[[MUL:.*]] = arith.mulf {{.*}}, {{.*}} : vector<16xf32> +// CHECK: %[[WRITE1:.*]] = vector.transfer_write {{.*}}, %[[arg3]][%[[arg2]]] {in_bounds = [true]} : vector<16xf32>, tensor<16xf32> +func.func @reduce_fuse_test13(%input: tensor<16x32x64xf32>, + %init: tensor<16xf32>) -> tensor<16xf32> { + %0 = linalg.add ins(%input, %input : tensor<16x32x64xf32>,tensor<16x32x64xf32>) + outs(%input : tensor<16x32x64xf32>) -> tensor<16x32x64xf32> + %reduce = linalg.reduce + ins(%0:tensor<16x32x64xf32>) + outs(%init:tensor<16xf32>) + dimensions = [1, 2] + (%in: f32, %out: f32) { + %2 = arith.addf %out, %in: f32 + linalg.yield %2: f32 + } + %1 = linalg.mul ins(%reduce, %reduce : tensor<16xf32>, tensor<16xf32>) outs(%init: tensor<16xf32>) -> tensor<16xf32> + func.return %1 : tensor<16xf32> +} + +// CHECK-LABEL: func @add_small_tensor_test14 // CHECK: %[[C2:.*]] = arith.constant 2 : index // CHECK: %[[C1:.*]] = arith.constant 1 : index // CHECK: %[[C0:.*]] = arith.constant 0 : index @@ -612,7 +654,7 @@ func.func @reduce_fuse_test12(%input: tensor<16x32x64xf32>, // CHECK: %[[ADD0:.*]] = arith.addf %[[READ1]], %[[READ0]] : vector<1xf32> // CHECK: %[[ADD1:.*]] = arith.maximumf %[[ADD0]], %[[CST]] : vector<1xf32> // CHECK: %[[WRITE:.*]] = vector.transfer_write {{.*}}, {{.*}} : vector<1xf32>, tensor<2xf32> -func.func @add_small_tensor_test13(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> { +func.func @add_small_tensor_test14(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> { %0 = tensor.empty() : tensor<2xf32> %cst = arith.constant dense<0.000000e+00> : tensor<2xf32> %1 = linalg.add ins(%arg0, %arg1 : tensor<2xf32>, tensor<2xf32>) outs(%0: tensor<2xf32>) -> tensor<2xf32> From a79147c5eeb6608f4493b30e480c40d94f19c882 Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Fri, 13 Sep 2024 15:23:30 +0800 Subject: [PATCH 51/66] add utils.cpp --- include/gc/Transforms/Utils/VectorUtils.h | 86 ++++++ lib/gc/Analysis/VectorBasedFusionAnalysis.cpp | 45 ++-- lib/gc/Transforms/CPUPhysicalRegisterPass.cpp | 255 +++--------------- lib/gc/Transforms/TilingVector.hpp | 3 +- lib/gc/Transforms/Utils/CMakeLists.txt | 1 + lib/gc/Transforms/Utils/VectorUtils.cpp | 164 +++++++++++ 6 files changed, 309 insertions(+), 245 deletions(-) create mode 100644 
include/gc/Transforms/Utils/VectorUtils.h create mode 100644 lib/gc/Transforms/Utils/VectorUtils.cpp diff --git a/include/gc/Transforms/Utils/VectorUtils.h b/include/gc/Transforms/Utils/VectorUtils.h new file mode 100644 index 000000000..550d0af1e --- /dev/null +++ b/include/gc/Transforms/Utils/VectorUtils.h @@ -0,0 +1,86 @@ +//===-- VectorUtils.h ----- vector fusion analysis --------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef GC_TRANSFORMS_UTILS_VECTORUTILS_H +#define GC_TRANSFORMS_UTILS_VECTORUTILS_H +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/TypeUtilities.h" +#include +#include +#include + +namespace mlir { +namespace gc { +union Float32Bits { + uint32_t u; + float f; +}; +uint16_t float2half(float floatValue); +float half2float(uint16_t halfValue); +uint16_t float2bfloat(float floatValue); +float bfloat2float(uint16_t bfloatBits); +std::variant numeric_limits_minimum(Type type); +std::variant numericLimitsMaximum(Type type); + +template +T getInitValForReduce(vector::CombiningKind kind, Type t) { + T result; + Type t1 = getElementTypeOrSelf(t); + + switch (kind) { + case vector::CombiningKind::ADD: + if (t1.isIntOrIndex()) + result = 0; + else if (isa(t1)) + result = 0.0f; + else + llvm_unreachable("invalid value types for ADD reduction"); + break; + case vector::CombiningKind::MAXNUMF: + case vector::CombiningKind::MAXIMUMF: + if (not isa(t1)) + llvm_unreachable("Expected float values."); + result = std::get(numeric_limits_minimum(t)); + break; + case vector::CombiningKind::MINNUMF: + case vector::CombiningKind::MINIMUMF: + if (not isa(t1)) + llvm_unreachable("Expected float values."); + result = std::get(numericLimitsMaximum(t)); + break; + case vector::CombiningKind::MAXSI: + case vector::CombiningKind::MAXUI: + if (not t1.isIntOrIndex()) + llvm_unreachable("Expected int or index values."); + result = std::get(numeric_limits_minimum(t)); + break; + case vector::CombiningKind::MINSI: + case vector::CombiningKind::MINUI: + if (not t1.isIntOrIndex()) + llvm_unreachable("Expected int or index values."); + result = std::get(numericLimitsMaximum(t)); + break; + case vector::CombiningKind::MUL: + if (t1.isIntOrIndex()) + result = 1; + else if (isa(t1)) + result = 1.f; + else + llvm_unreachable("invalid value types for MUL reduction"); + break; + default: + llvm_unreachable("unsupported reduction kind"); + }; + return result; +} + +} // namespace gc +} // namespace mlir + +#endif \ No newline at end of file diff --git a/lib/gc/Analysis/VectorBasedFusionAnalysis.cpp b/lib/gc/Analysis/VectorBasedFusionAnalysis.cpp index 2760ada39..fb7ee4048 100644 --- a/lib/gc/Analysis/VectorBasedFusionAnalysis.cpp +++ b/lib/gc/Analysis/VectorBasedFusionAnalysis.cpp @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// #include "gc/Analysis/VectorBasedFusionAnalysis.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "gc/Dialect/Linalgx/Utils.h" namespace mlir { namespace gc { @@ -22,16 +22,16 @@ namespace gc { arith::TruncFOp, arith::TruncIOp #define NOT_NEED_TO_PROCESS_OP \ - linalg::GenericOp, linalg::BatchReduceMatmulOp, linalg::MatmulOp, \ - linalg::BatchMatmulOp, linalg::BatchMatmulTransposeAOp, \ - 
linalg::BatchMatmulTransposeBOp, linalg::MatmulTransposeAOp, \ - linalg::MatmulTransposeBOp, linalg::QuantizedBatchMatmulOp, \ - linalg::QuantizedMatmulOp, tensor::CollapseShapeOp, \ - tensor::ExpandShapeOp, tensor::ExtractSliceOp, tensor::InsertSliceOp, \ - microkernel::BrgemmOp + linalg::BatchReduceMatmulOp, linalg::MatmulOp, linalg::BatchMatmulOp, \ + linalg::BatchMatmulTransposeAOp, linalg::BatchMatmulTransposeBOp, \ + linalg::MatmulTransposeAOp, linalg::MatmulTransposeBOp, \ + linalg::QuantizedBatchMatmulOp, linalg::QuantizedMatmulOp, \ + tensor::CollapseShapeOp, tensor::ExpandShapeOp, tensor::ExtractSliceOp, \ + tensor::InsertSliceOp, microkernel::BrgemmOp static inline bool isNotNeedToProcessOp(Operation *op) { - return isa(op); + return isa(op) or + linalgx::isAnyGenericPackedMatmulOp(op); } static inline bool isSpecialOp(Operation *op) { @@ -72,7 +72,7 @@ void shapeCastSourceAxis(const ArrayRef &a, const ArrayRef &b, while (dimB < dimA && j < rankB) dimB *= b[j++]; if (dimA != dimB) { - assert(false && " Invalid shape cast operation."); + llvm::llvm_unreachable_internal(" Invalid shape cast operation."); break; } if (bAxisBegin != j) { @@ -87,12 +87,13 @@ void shapeCastSourceAxis(const ArrayRef &a, const ArrayRef &b, if (j < rankB && all_of(b.slice(j), isOne)) j = rankB; } - - assert(i == rankA && j == rankB && "Invalid shapecast operation."); + if (i != rankA or j != rankB) + llvm_unreachable("Invalid shapecast operation."); } bool isScalar(Type type) { - assert(type && "Not a valid type"); + if (not type) + llvm_unreachable("Not a valid type"); if (auto vecType = dyn_cast(type)) return false; if (auto tensorType = dyn_cast(type)) @@ -107,8 +108,8 @@ void getSrcBroadcastDim(const ShapedType &input, const ShapedType &output, // following auto_broadcast semantics const size_t input_rank = inputShape.size(); const size_t output_rank = outputShape.size(); - assert(output_rank >= input_rank && - "Incorrect input or output shape for broadcast op."); + if (output_rank < input_rank) + llvm_unreachable("Incorrect input or output shape for broadcast op."); const size_t offset = output_rank - input_rank; for (size_t i = 0; i < input_rank; ++i) { if (inputShape[i] == outputShape[i + offset] || @@ -390,13 +391,16 @@ mlir::FailureOr getOperationMaxVectorType(Operation *op) { /// select nearest even step int getNearestVectorStep(const int step) { - assert(step > 0); + if (step <= 0) + llvm_unreachable("Wrong step."); + int nbits = 0, n = step; while (n) { n = n >> 1; nbits++; } - assert(nbits <= 6 || (nbits == 7 && step == 64)); + if (nbits > 6 and !(nbits == 7 && step == 64)) + llvm_unreachable("wrong nbits appear"); return (1 << (nbits - 1)) == step ? step : (1 << nbits); } @@ -488,7 +492,7 @@ VectorType TypeHelper::getVectorzedType(Operation *op, uint32_t loopStep) { // down into a loop. 
mlir::FailureOr baseType = getOperationVectorType(op); if (failed(baseType)) { - assert(0 && "Failed to get vector type for operation"); + llvm_unreachable("Failed to get vector type for operation"); return VectorType(); } auto vectorizedType = baseType.value(); @@ -518,7 +522,7 @@ int TypeHelper::generateValidSteps(int steps, VectorType type) { return favx2bits / typebits; // invalid hardware - assert(false && "Invalid hardware."); + llvm_unreachable("Invalid hardware."); return -1; } @@ -590,7 +594,8 @@ void GroupOperationFusion::updateGroupBigestVectorType(VectorType vectorType) { } void GroupOperationFusion::addOperationToGroup(Operation *op) { - assert(op); + if (not op) + llvm_unreachable("Op can't be NULL."); VectorType vectorType = getOperationMaxVectorType(op).value(); if (isNeedNewGroup(op)) opGroups.emplace_back(std::queue()); diff --git a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp index 888af1c79..cda26dd6e 100644 --- a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp +++ b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp @@ -47,7 +47,7 @@ void printGroupOps(SmallVector, 8> &opGroups) { } } -static inline bool isUsedByOtherOp(Operation *op) { +static inline bool isProducerOp(Operation *op) { return isa(op); } @@ -77,7 +77,9 @@ static inline bool isSpecialOp(Operation *op) { static inline void moveOpBeginingOfBlock(Operation *op) { Block *block = op->getBlock(); - assert(not block->getOperations().empty() && "Empty block."); + if (block->getOperations().empty()) + llvm_unreachable("Emtpy block."); + if (&block->front() == op) return; op->moveBefore(&block->front()); @@ -236,11 +238,11 @@ FailureOr createArithSplatConstantOp(IRRewriter &rewriter, return failure(); TypedAttr attr; - if (isa(newOperandType.getElementType())) { + if (isa(newOperandType.getElementType())) getConstantDenseAttr(attr, newOperandType, valueType); - } else { + else getConstantDenseAttr(attr, newOperandType, valueType); - } + return rewriter.create(loc, attr)->getResults()[0]; } @@ -253,211 +255,6 @@ bool needReturnResult(std::pair &retType, retType.second < anchorIdx; } -union Float32Bits { - uint32_t u; - float f; -}; - -const uint32_t kF32MantiBits = 23; -const uint32_t kF32HalfMantiBitDiff = 13; -const uint32_t kF32HalfBitDiff = 16; -const Float32Bits kF32Magic = {113 << kF32MantiBits}; -const uint32_t kF32HalfExpAdjust = (127 - 15) << kF32MantiBits; -const uint32_t kF32BfMantiBitDiff = 16; - -/// Constructs the 16 bit representation for a half precision value from a float -/// value. This implementation is adapted from Eigen. -uint16_t float2half(float floatValue) { - const Float32Bits inf = {255 << kF32MantiBits}; - const Float32Bits f16max = {(127 + 16) << kF32MantiBits}; - const Float32Bits denormMagic = {((127 - 15) + (kF32MantiBits - 10) + 1) - << kF32MantiBits}; - uint32_t signMask = 0x80000000u; - uint16_t halfValue = static_cast(0x0u); - Float32Bits f; - f.f = floatValue; - uint32_t sign = f.u & signMask; - f.u ^= sign; - - if (f.u >= f16max.u) { - const uint32_t halfQnan = 0x7e00; - const uint32_t halfInf = 0x7c00; - // Inf or NaN (all exponent bits set). - halfValue = (f.u > inf.u) ? halfQnan : halfInf; // NaN->qNaN and Inf->Inf - } else { - // (De)normalized number or zero. - if (f.u < kF32Magic.u) { - // The resulting FP16 is subnormal or zero. - // - // Use a magic value to align our 10 mantissa bits at the bottom of the - // float. As long as FP addition is round-to-nearest-even this works. 
- f.f += denormMagic.f; - - halfValue = static_cast(f.u - denormMagic.u); - } else { - uint32_t mantOdd = - (f.u >> kF32HalfMantiBitDiff) & 1; // Resulting mantissa is odd. - - // Update exponent, rounding bias part 1. The following expressions are - // equivalent to `f.u += ((unsigned int)(15 - 127) << kF32MantiBits) + - // 0xfff`, but without arithmetic overflow. - f.u += 0xc8000fffU; - // Rounding bias part 2. - f.u += mantOdd; - halfValue = static_cast(f.u >> kF32HalfMantiBitDiff); - } - } - - halfValue |= static_cast(sign >> kF32HalfBitDiff); - return halfValue; -} - -/// Converts the 16 bit representation of a half precision value to a float -/// value. This implementation is adapted from Eigen. -float half2float(uint16_t halfValue) { - const uint32_t shiftedExp = - 0x7c00 << kF32HalfMantiBitDiff; // Exponent mask after shift. - - // Initialize the float representation with the exponent/mantissa bits. - Float32Bits f = { - static_cast((halfValue & 0x7fff) << kF32HalfMantiBitDiff)}; - const uint32_t exp = shiftedExp & f.u; - f.u += kF32HalfExpAdjust; // Adjust the exponent - - // Handle exponent special cases. - if (exp == shiftedExp) { - // Inf/NaN - f.u += kF32HalfExpAdjust; - } else if (exp == 0) { - // Zero/Denormal? - f.u += 1 << kF32MantiBits; - f.f -= kF32Magic.f; - } - - f.u |= (halfValue & 0x8000) << kF32HalfBitDiff; // Sign bit. - return f.f; -} - -// Constructs the 16 bit representation for a bfloat value from a float value. -// This implementation is adapted from Eigen. -uint16_t float2bfloat(float floatValue) { - if (std::isnan(floatValue)) - return std::signbit(floatValue) ? 0xFFC0 : 0x7FC0; - - Float32Bits floatBits; - floatBits.f = floatValue; - uint16_t bfloatBits; - - // Least significant bit of resulting bfloat. - uint32_t lsb = (floatBits.u >> kF32BfMantiBitDiff) & 1; - uint32_t roundingBias = 0x7fff + lsb; - floatBits.u += roundingBias; - bfloatBits = static_cast(floatBits.u >> kF32BfMantiBitDiff); - return bfloatBits; -} - -// Converts the 16 bit representation of a bfloat value to a float value. This -// implementation is adapted from Eigen. 
-float bfloat2float(uint16_t bfloatBits) { - Float32Bits floatBits; - floatBits.u = static_cast(bfloatBits) << kF32BfMantiBitDiff; - return floatBits.f; -} - -std::variant numeric_limits_minimum(Type type) { - Type t1 = getElementTypeOrSelf(type); - if (t1.isF32()) { - return -std::numeric_limits::infinity(); - } else if (t1.isBF16()) { - return bfloat2float(float2bfloat(-std::numeric_limits::infinity())); - } else if (t1.isF16()) { - return (float)half2float( - float2half(-std::numeric_limits::infinity())); - } else if (t1.isSignedInteger(8)) { - return int64_t(-128); - } else if (t1.isSignedInteger(32)) { - return int64_t(std::numeric_limits::min()); - } else if (t1.isSignlessInteger(8) or t1.isSignlessInteger(32)) { - return int64_t(0); - } else { - LDBG("Unsupported data type: " << t1 << "\n"); - assert(0 && "unsupported data type"); - return (int64_t)0; - } -} - -std::variant numericLimitsMaximum(Type type) { - Type t1 = getElementTypeOrSelf(type); - if (t1.isF32()) { - return std::numeric_limits::infinity(); - } else if (t1.isBF16()) { - return bfloat2float(float2bfloat(std::numeric_limits::infinity())); - } else if (t1.isF16()) { - return (float)half2float( - float2half(std::numeric_limits::infinity())); - } else if (t1.isSignedInteger(8)) { - return int64_t(127); - } else if (t1.isSignedInteger(32)) { - return int64_t(std::numeric_limits::max()); - } else if (t1.isSignlessInteger(8)) { - return int64_t(255); - } else if (t1.isSignedInteger(32)) { - return int64_t(std::numeric_limits::max()); - } else { - LDBG("Unsupported data type: " << t1 << "\n"); - assert(0 && "unsupported data type"); - return (int64_t)0; - } -} - -template -T getInitValForReduce(vector::CombiningKind kind, Type t) { - T result; - Type t1 = getElementTypeOrSelf(t); - - switch (kind) { - case vector::CombiningKind::ADD: - if (t1.isIntOrIndex()) - result = 0; - else if (isa(t1)) - result = 0.0f; - else - llvm_unreachable("invalid value types for ADD reduction"); - break; - case vector::CombiningKind::MAXNUMF: - case vector::CombiningKind::MAXIMUMF: - assert(isa(t1) && "expected float values"); - result = std::get(numeric_limits_minimum(t)); - break; - case vector::CombiningKind::MINNUMF: - case vector::CombiningKind::MINIMUMF: - assert(isa(t1) && "expected float values"); - result = std::get(numericLimitsMaximum(t)); - break; - case vector::CombiningKind::MAXSI: - case vector::CombiningKind::MAXUI: - assert(t1.isIntOrIndex() && "expected int values"); - result = std::get(numeric_limits_minimum(t)); - break; - case vector::CombiningKind::MINSI: - case vector::CombiningKind::MINUI: - assert(t1.isIntOrIndex() && "expected int values"); - result = std::get(numericLimitsMaximum(t)); - break; - case vector::CombiningKind::MUL: - if (t1.isIntOrIndex()) - result = 1; - else if (isa(t1)) - result = 1.f; - else - llvm_unreachable("invalid value types for MUL reduction"); - break; - default: - llvm_unreachable("unsupported reduction kind"); - }; - return result; -} - // Since we rewrite transfer_read and transfer_write, the `permutationmap` must // be changed. 
void setOpVectorizationPermutationMap(Operation *op, OpBuilder &rewriter, @@ -465,7 +262,8 @@ void setOpVectorizationPermutationMap(Operation *op, OpBuilder &rewriter, const AffineMap &permutationMap) { auto dimExpr = permutationMap.getResults(); auto lastDim = dyn_cast(dimExpr.back()); - assert(isa(lastDim)); + if (not isa(lastDim)) + llvm_unreachable("Must be AffineDimExpr."); SmallVector affineExprs; affineExprs.push_back(lastDim); @@ -677,8 +475,10 @@ void getPrevOps(std::queue &prevOps, void getPostOps(std::queue &postOps, std::queue &opQueue, Operation *currentOp) { // pop multireduction op - assert(currentOp == opQueue.front() && "Current operation is not the front " - "operation of the operation queue."); + if (currentOp != opQueue.front()) + llvm_unreachable( + "Current operation is not the front operation of the operation queue."); + opQueue.pop(); while (!opQueue.empty()) { postOps.push(opQueue.front()); @@ -805,7 +605,8 @@ void ForLoopGenerator::getInitArgsToNextAnchor( for (auto x : curOperands) { if (!visited.contains(x) and opInitArgs.contains(x) and opAnchorPos[cur] > loopHelperParam.anchorIdx) { - assert(loopHelperParam.originalOperandLoopArgsMap.contains(x)); + if (not loopHelperParam.originalOperandLoopArgsMap.contains(x)) + llvm_unreachable("Must contains current value."); int loopStateIdx = loopHelperParam.currentLoopStateIdxMap [loopHelperParam.originalOperandLoopArgsMap[x]]; updateCurrentArgsStatus(loopHelperParam.loopIterArgs, loopStateIdx, @@ -1221,8 +1022,8 @@ scf::ForOp LoopGeneratorImpl::parallelAxisGenerateForLoop( DenseMap &opIndexMap = fusionStrategy.getOpGroupIndexMap(); - assert(opIndexMap.contains(multiReductionOp) && - " Must constains multireduction operation."); + if (not opIndexMap.contains(multiReductionOp)) + llvm_unreachable("Must constains multireduction operation."); size_t opIndex = opIndexMap[multiReductionOp]; SmallVector, 8> &opGroups = @@ -1686,7 +1487,8 @@ scf::ForOp LoopGeneratorImpl::generateShapeCastReadWriteLoop( while ((int64_t)itrIdx < smallType.getRank()) { size_t endShape = getFirstTrueIndex(visitedAxis), dimSize = 1; - assert(endShape < rank and endShape >= 0 && "Invalid endShape"); + if (endShape >= rank) + llvm_unreachable("Invalid shape."); // skip non corresponding axis // e.g.: vector<32x16x1x32xbf16> -> vector<1x512x32xbf16> while (loopType.getShape()[endShape] > @@ -2095,8 +1897,9 @@ bool ShapeCastCanonicalizer::isReadWriteOnLastDim() { int64_t itrIdx = 0; while (itrIdx < smallRankType.getRank()) { int64_t endShape = getFirstTrueIndex(visitedAxis), dimSize = 1; - assert(endShape < largeRankType.getRank() and endShape >= 0 && - "Invalid endShape"); + if (endShape >= largeRankType.getRank() or endShape < 0) + llvm_unreachable("Invalid endShape."); + // skip non corresponding axis // e.g.: vector<32x16x1x32xbf16> -> vector<1x512x32xbf16> while (largeRankType.getShape()[endShape] > @@ -2301,7 +2104,9 @@ void ForLoopGenerator::setOperationCorrectOperand( int offset = isa(op) ? 
2 : 1; if (dyn_cast(op) || dyn_cast(op)) { - assert(opPermuationMap.contains(op)); + if (not opPermuationMap.contains(op)) + llvm_unreachable("Map must contains operation."); + auto permutationMap = opPermuationMap.at(op); auto dimExpr = permutationMap.getResults(); @@ -2459,7 +2264,7 @@ Value setOutGroupOperationOperandResult(Operation *op, } else { // write original vector into tensor // then we transfer_read from the tensor - assert(0 && "Not support non-splat constant value."); + llvm_unreachable("Not support non-splat constant value."); } } else if (isa(resultElementType)) { initValueAttr = FloatAttr::get( @@ -2614,7 +2419,7 @@ void ForLoopGenerator::rewriteOperationAsVectorize( }); if (failed(lowerResult)) { LDBG("Failed to rewrite operation: " << *op << "\n"); - assert(false && "Failed to rewrite operation"); + llvm_unreachable("Failed to rewrite operation"); } } } @@ -2763,7 +2568,9 @@ void GroupOperationFusionImpl::specialOperationRectify( // remain transfer read operation to do the broadcast fusion if (isa(op)) { auto srcOp = op->getOperand(0).getDefiningOp(); - assert(isa(srcOp)); + if (not isa(srcOp)) + llvm_unreachable("Must be read operation."); + // only have write operation, otherwise the group size will bigger // than 1. Because the last operation is always a write operation in // each group @@ -3221,7 +3028,7 @@ struct CPUPhysicalRegisterPass return; } // affineApply operation is always used by other operations. - std::function candidateFunc = isUsedByOtherOp; + std::function candidateFunc = isProducerOp; moveSomeInterferenceOperation(&func, ctx, candidateFunc); candidateFunc = isCandidateMoveOperations; moveSomeInterferenceOperation(&func, ctx, candidateFunc); diff --git a/lib/gc/Transforms/TilingVector.hpp b/lib/gc/Transforms/TilingVector.hpp index 0545325e6..b0d0fea37 100644 --- a/lib/gc/Transforms/TilingVector.hpp +++ b/lib/gc/Transforms/TilingVector.hpp @@ -1,4 +1,4 @@ -//===- TilingVector.h - Tiling large vector to small vector -----*- C++ -*-===// +//===- TilingVector.hpp - Tiling large vector to small vector ---*- C++ -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -10,6 +10,7 @@ #include "gc/Analysis//VectorBasedFusionAnalysis.h" #include "gc/Analysis/TargetDescriptionAnalysis.h" +#include "gc/Transforms/Utils/VectorUtils.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" diff --git a/lib/gc/Transforms/Utils/CMakeLists.txt b/lib/gc/Transforms/Utils/CMakeLists.txt index 94b700435..4742a37ac 100644 --- a/lib/gc/Transforms/Utils/CMakeLists.txt +++ b/lib/gc/Transforms/Utils/CMakeLists.txt @@ -2,6 +2,7 @@ gc_add_mlir_library(GcUtilsIR MatcherUtils.cpp StructuredOpMatcher.cpp ValueUtils.cpp + VectorUtils.cpp DEPENDS MLIRLinalgDialect diff --git a/lib/gc/Transforms/Utils/VectorUtils.cpp b/lib/gc/Transforms/Utils/VectorUtils.cpp new file mode 100644 index 000000000..ae169c878 --- /dev/null +++ b/lib/gc/Transforms/Utils/VectorUtils.cpp @@ -0,0 +1,164 @@ +//===- VectorUtils.cpp - analysis vector ops --------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#include "gc/Transforms/Utils/VectorUtils.h" + +namespace mlir { +namespace gc { + +const uint32_t kF32MantiBits = 23; +const uint32_t kF32HalfMantiBitDiff = 13; +const uint32_t kF32HalfBitDiff = 16; +const Float32Bits kF32Magic = {113 << kF32MantiBits}; +const uint32_t kF32HalfExpAdjust = (127 - 15) << kF32MantiBits; +const uint32_t kF32BfMantiBitDiff = 16; + +/// Constructs the 16 bit representation for a half precision value from a float +/// value. This implementation is adapted from Eigen. +uint16_t float2half(float floatValue) { + const Float32Bits inf = {255 << kF32MantiBits}; + const Float32Bits f16max = {(127 + 16) << kF32MantiBits}; + const Float32Bits denormMagic = {((127 - 15) + (kF32MantiBits - 10) + 1) + << kF32MantiBits}; + uint32_t signMask = 0x80000000u; + uint16_t halfValue = static_cast(0x0u); + Float32Bits f; + f.f = floatValue; + uint32_t sign = f.u & signMask; + f.u ^= sign; + + if (f.u >= f16max.u) { + const uint32_t halfQnan = 0x7e00; + const uint32_t halfInf = 0x7c00; + // Inf or NaN (all exponent bits set). + halfValue = (f.u > inf.u) ? halfQnan : halfInf; // NaN->qNaN and Inf->Inf + } else { + // (De)normalized number or zero. + if (f.u < kF32Magic.u) { + // The resulting FP16 is subnormal or zero. + // + // Use a magic value to align our 10 mantissa bits at the bottom of the + // float. As long as FP addition is round-to-nearest-even this works. + f.f += denormMagic.f; + + halfValue = static_cast(f.u - denormMagic.u); + } else { + uint32_t mantOdd = + (f.u >> kF32HalfMantiBitDiff) & 1; // Resulting mantissa is odd. + + // Update exponent, rounding bias part 1. The following expressions are + // equivalent to `f.u += ((unsigned int)(15 - 127) << kF32MantiBits) + + // 0xfff`, but without arithmetic overflow. + f.u += 0xc8000fffU; + // Rounding bias part 2. + f.u += mantOdd; + halfValue = static_cast(f.u >> kF32HalfMantiBitDiff); + } + } + + halfValue |= static_cast(sign >> kF32HalfBitDiff); + return halfValue; +} + +/// Converts the 16 bit representation of a half precision value to a float +/// value. This implementation is adapted from Eigen. +float half2float(uint16_t halfValue) { + const uint32_t shiftedExp = + 0x7c00 << kF32HalfMantiBitDiff; // Exponent mask after shift. + + // Initialize the float representation with the exponent/mantissa bits. + Float32Bits f = { + static_cast((halfValue & 0x7fff) << kF32HalfMantiBitDiff)}; + const uint32_t exp = shiftedExp & f.u; + f.u += kF32HalfExpAdjust; // Adjust the exponent + + // Handle exponent special cases. + if (exp == shiftedExp) { + // Inf/NaN + f.u += kF32HalfExpAdjust; + } else if (exp == 0) { + // Zero/Denormal? + f.u += 1 << kF32MantiBits; + f.f -= kF32Magic.f; + } + + f.u |= (halfValue & 0x8000) << kF32HalfBitDiff; // Sign bit. + return f.f; +} + +// Constructs the 16 bit representation for a bfloat value from a float value. +// This implementation is adapted from Eigen. +uint16_t float2bfloat(float floatValue) { + if (std::isnan(floatValue)) + return std::signbit(floatValue) ? 0xFFC0 : 0x7FC0; + + Float32Bits floatBits; + floatBits.f = floatValue; + uint16_t bfloatBits; + + // Least significant bit of resulting bfloat. 
+ uint32_t lsb = (floatBits.u >> kF32BfMantiBitDiff) & 1; + uint32_t roundingBias = 0x7fff + lsb; + floatBits.u += roundingBias; + bfloatBits = static_cast(floatBits.u >> kF32BfMantiBitDiff); + return bfloatBits; +} + +// Converts the 16 bit representation of a bfloat value to a float value. This +// implementation is adapted from Eigen. +float bfloat2float(uint16_t bfloatBits) { + Float32Bits floatBits; + floatBits.u = static_cast(bfloatBits) << kF32BfMantiBitDiff; + return floatBits.f; +} + +std::variant numeric_limits_minimum(Type type) { + Type t1 = getElementTypeOrSelf(type); + if (t1.isF32()) { + return -std::numeric_limits::infinity(); + } else if (t1.isBF16()) { + return bfloat2float(float2bfloat(-std::numeric_limits::infinity())); + } else if (t1.isF16()) { + return (float)half2float( + float2half(-std::numeric_limits::infinity())); + } else if (t1.isSignedInteger(8)) { + return int64_t(-128); + } else if (t1.isSignedInteger(32)) { + return int64_t(std::numeric_limits::min()); + } else if (t1.isSignlessInteger(8) or t1.isSignlessInteger(32)) { + return int64_t(0); + } else { + llvm_unreachable("unsupported data type"); + return (int64_t)0; + } +} + +std::variant numericLimitsMaximum(Type type) { + Type t1 = getElementTypeOrSelf(type); + if (t1.isF32()) { + return std::numeric_limits::infinity(); + } else if (t1.isBF16()) { + return bfloat2float(float2bfloat(std::numeric_limits::infinity())); + } else if (t1.isF16()) { + return (float)half2float( + float2half(std::numeric_limits::infinity())); + } else if (t1.isSignedInteger(8)) { + return int64_t(127); + } else if (t1.isSignedInteger(32)) { + return int64_t(std::numeric_limits::max()); + } else if (t1.isSignlessInteger(8)) { + return int64_t(255); + } else if (t1.isSignedInteger(32)) { + return int64_t(std::numeric_limits::max()); + } else { + llvm_unreachable("unsupported data type"); + return (int64_t)0; + } +} + +} // namespace gc +} // namespace mlir \ No newline at end of file From 3552eb662112c659c573a659fd7d84ed426050b1 Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Fri, 13 Sep 2024 15:25:16 +0800 Subject: [PATCH 52/66] fix license --- include/gc/Transforms/Utils/VectorUtils.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/gc/Transforms/Utils/VectorUtils.h b/include/gc/Transforms/Utils/VectorUtils.h index 550d0af1e..328d3fcef 100644 --- a/include/gc/Transforms/Utils/VectorUtils.h +++ b/include/gc/Transforms/Utils/VectorUtils.h @@ -1,4 +1,4 @@ -//===-- VectorUtils.h ----- vector fusion analysis --------------*- C++ -*-===// +//===-- VectorUtils.h - vector fusion analysis ------------------*- C++ -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
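The two commits above move the scalar numeric helpers (float2half/half2float, float2bfloat/bfloat2float, numeric_limits_minimum/numericLimitsMaximum and getInitValForReduce) out of CPUPhysicalRegisterPass.cpp and into gc/Transforms/Utils/VectorUtils.{h,cpp}. The standalone sketch below is illustrative only and is not part of the patch: it assumes the declarations shown in include/gc/Transforms/Utils/VectorUtils.h (including that the stripped template parameter of getInitValForReduce is the element type T) and a normal MLIR build. It round-trips a value through the bf16 helpers and queries the neutral elements used to seed vectorized reductions.

#include "gc/Transforms/Utils/VectorUtils.h" // helpers introduced above (assumed path)
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
  using namespace mlir;
  // bf16 round trip: 1.5f is exactly representable in bfloat16, so the
  // conversion helpers should reproduce it bit-exactly.
  uint16_t bits = gc::float2bfloat(1.5f);
  float back = gc::bfloat2float(bits);
  assert(back == 1.5f);

  // Neutral elements for reduction seeding on an f32 element type:
  // 0.0f for ADD, -inf for MAXIMUMF (taken from numeric_limits_minimum).
  MLIRContext ctx;
  Type f32 = Float32Type::get(&ctx);
  float addInit =
      gc::getInitValForReduce<float>(vector::CombiningKind::ADD, f32);
  float maxInit =
      gc::getInitValForReduce<float>(vector::CombiningKind::MAXIMUMF, f32);
  std::printf("add init = %f, max init = %f\n", addInit, maxInit);
  return 0;
}

A small check of this shape is also a convenient place to pin down the rounding behaviour of float2bfloat (round-to-nearest-even via the 0x7fff + lsb bias) before the reduction lowering starts relying on these helpers for accumulator initialization.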
From 3ef9f3ad1f3e5fa29c9884ed19938740104709c3 Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Fri, 13 Sep 2024 15:38:42 +0800 Subject: [PATCH 53/66] fix code stype --- lib/gc/Analysis/VectorBasedFusionAnalysis.cpp | 2 +- lib/gc/Transforms/TilingVector.hpp | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/lib/gc/Analysis/VectorBasedFusionAnalysis.cpp b/lib/gc/Analysis/VectorBasedFusionAnalysis.cpp index fb7ee4048..b8f3b6e36 100644 --- a/lib/gc/Analysis/VectorBasedFusionAnalysis.cpp +++ b/lib/gc/Analysis/VectorBasedFusionAnalysis.cpp @@ -399,7 +399,7 @@ int getNearestVectorStep(const int step) { n = n >> 1; nbits++; } - if (nbits > 6 and !(nbits == 7 && step == 64)) + if (nbits > 6 and (nbits != 7 or step != 64)) llvm_unreachable("wrong nbits appear"); return (1 << (nbits - 1)) == step ? step : (1 << nbits); } diff --git a/lib/gc/Transforms/TilingVector.hpp b/lib/gc/Transforms/TilingVector.hpp index b0d0fea37..045a6c677 100644 --- a/lib/gc/Transforms/TilingVector.hpp +++ b/lib/gc/Transforms/TilingVector.hpp @@ -276,7 +276,7 @@ class BroadcastCanonicalizer const SmallVector &candidateBcOps, size_t steps = 1) : SpecialOperationCanonicalizer( - candidateBcOps, SpecialOperationKind::OP_Broadcast, steps) {}; + candidateBcOps, SpecialOperationKind::OP_Broadcast, steps){}; virtual ~BroadcastCanonicalizer() noexcept {} void prepareSpecialOperationInfo() override {} static bool classof(SpecialOperationCanonicalizer *canonicalizer) { @@ -295,9 +295,9 @@ class TransposeCanonicalizer const llvm::SmallVector &candidateTpOps, size_t steps = 1) : SpecialOperationCanonicalizer( - candidateTpOps, SpecialOperationKind::OP_Transpose, steps) {}; + candidateTpOps, SpecialOperationKind::OP_Transpose, steps){}; virtual ~TransposeCanonicalizer() noexcept {} - void prepareSpecialOperationInfo() override {}; + void prepareSpecialOperationInfo() override{}; static bool classof(SpecialOperationCanonicalizer *canonicalizer) { return canonicalizer->getKind() == SpecialOperationKind::OP_Transpose; } @@ -324,7 +324,7 @@ class ShapeCastCanonicalizer const SmallVector &candidateScOps, size_t steps = 1) : SpecialOperationCanonicalizer( - candidateScOps, SpecialOperationKind::OP_ShapeCast, steps) {}; + candidateScOps, SpecialOperationKind::OP_ShapeCast, steps){}; virtual ~ShapeCastCanonicalizer() {} void prepareSpecialOperationInfo() override {} static bool classof(SpecialOperationCanonicalizer *canonicalizer) { @@ -460,7 +460,7 @@ class LoopGeneratorImpl : public ForLoopGenerator { SmallVector shapeCastCanonicalizers; public: - LoopGeneratorImpl(GroupOperationFusion &fusion) : ForLoopGenerator(fusion) {}; + LoopGeneratorImpl(GroupOperationFusion &fusion) : ForLoopGenerator(fusion){}; virtual ~LoopGeneratorImpl() noexcept {}; From a1f9988a027460a9fba1d9008934da5b155fe75a Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Fri, 13 Sep 2024 15:50:57 +0800 Subject: [PATCH 54/66] enable broadcast op fusion --- lib/gc/Analysis/VectorBasedFusionAnalysis.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/lib/gc/Analysis/VectorBasedFusionAnalysis.cpp b/lib/gc/Analysis/VectorBasedFusionAnalysis.cpp index b8f3b6e36..4a2b12cbe 100644 --- a/lib/gc/Analysis/VectorBasedFusionAnalysis.cpp +++ b/lib/gc/Analysis/VectorBasedFusionAnalysis.cpp @@ -207,9 +207,6 @@ bool hasDataDependency(Operation *op1, Operation *op2) { if (isa(op2)) return true; - if (isa(op1)) - return true; - // only special operation may cause data dependency if (!isSpecialOp(op1)) return hasDataDependency(op2, op1); From 
47687f1c343eae1fe5830d98db21ce6b7c458807 Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Fri, 13 Sep 2024 16:12:10 +0800 Subject: [PATCH 55/66] add lower to vector pass in pipeline --- .../gc/Analysis/VectorBasedFusionAnalysis.h | 10 +- include/gc/Transforms/Utils/VectorUtils.h | 32 ++++ lib/gc/Analysis/VectorBasedFusionAnalysis.cpp | 124 --------------- lib/gc/Transforms/CPUPhysicalRegisterPass.cpp | 38 ----- lib/gc/Transforms/Pipeline.cpp | 2 + lib/gc/Transforms/TilingVector.hpp | 19 --- lib/gc/Transforms/Utils/VectorUtils.cpp | 149 ++++++++++++++++++ 7 files changed, 186 insertions(+), 188 deletions(-) diff --git a/include/gc/Analysis/VectorBasedFusionAnalysis.h b/include/gc/Analysis/VectorBasedFusionAnalysis.h index 8a0d354c6..92c9abf22 100644 --- a/include/gc/Analysis/VectorBasedFusionAnalysis.h +++ b/include/gc/Analysis/VectorBasedFusionAnalysis.h @@ -12,6 +12,7 @@ #include "gc/Dialect/Linalgx/LinalgxOps.h" #include "gc/Dialect/Microkernel/MicrokernelOps.h" #include "gc/Transforms/Passes.h" +#include "gc/Transforms/Utils/VectorUtils.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -25,11 +26,6 @@ namespace mlir { namespace gc { -mlir::FailureOr getOperationVectorType(Operation *op, - bool isPrevOp = true); -int getNearestVectorStep(const int step); -mlir::FailureOr getOperationMaxVectorType(Operation *op); - /// record hardware information struct HardWareInfo { bool favx512f = true; @@ -136,7 +132,7 @@ class GroupOperationFusion : public VectorFusionBase { : VectorFusionBase(strategy.getFunction(), strategy.getHardwareInfo()), opGroups(strategy.opGroups), groupMaxSteps(strategy.groupMaxSteps), opGroupIndexMap(strategy.opGroupIndexMap), - opAnchorPos(strategy.opAnchorPos) {}; + opAnchorPos(strategy.opAnchorPos){}; GroupOperationFusion(GroupOperationFusion &&strategy) : VectorFusionBase(strategy.getFunction(), strategy.getHardwareInfo()), @@ -145,7 +141,7 @@ class GroupOperationFusion : public VectorFusionBase { groupBigestRankVectorType( std::move(strategy.getGroupBiggestRankVectorType())), opGroupIndexMap(std::move(strategy.opGroupIndexMap)), - opAnchorPos(std::move(strategy.opAnchorPos)) {}; + opAnchorPos(std::move(strategy.opAnchorPos)){}; GroupOperationFusion &operator=(GroupOperationFusion &fusion) { this->getOpGroups() = fusion.getOpGroups(); diff --git a/include/gc/Transforms/Utils/VectorUtils.h b/include/gc/Transforms/Utils/VectorUtils.h index 328d3fcef..a113d5f3a 100644 --- a/include/gc/Transforms/Utils/VectorUtils.h +++ b/include/gc/Transforms/Utils/VectorUtils.h @@ -11,12 +11,44 @@ #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/TypeUtilities.h" +#include "llvm/ADT/TypeSwitch.h" #include #include #include namespace mlir { namespace gc { +/// build a constant operation of index type +Value makeIndexArithConstantOp(OpBuilder &opBuilder, const Location &loc, + int64_t x); + +/// find the original tensor +Value findOriginalTensor(Value writeTensor, Block *block); +/// get operation read or write tensor +mlir::FailureOr getOperationOperateTensor(Operation *op); + +/// set correct operand for the operation +void setOperationCorrectOperand( + Operation *op, ValueRange iterArgs, DenseMap &operandIdxMap, + DenseMap &originalOperandLoopArgsMap, + ArrayRef inductionVars, + DenseMap &opPermuationMap); + +/// Get vector type of the operation \param op +/// \param isPrevOp whether the operation is a previous operation, if it is not +/// prev-op, may need to 
use result vectortype +/// default will return the opeation result type +mlir::FailureOr getOperationVectorType(Operation *op, + bool isPrevOp = true); + +/// select nearest even step +int getNearestVectorStep(const int step); + +/// get operation vector type +/// \param isPrevOp whether the operation is a previous operation, if it is not +/// prev-op, may need to use result vectortype +/// default will return the opeation result type +mlir::FailureOr getOperationMaxVectorType(Operation *op); union Float32Bits { uint32_t u; float f; diff --git a/lib/gc/Analysis/VectorBasedFusionAnalysis.cpp b/lib/gc/Analysis/VectorBasedFusionAnalysis.cpp index 4a2b12cbe..5cee5314f 100644 --- a/lib/gc/Analysis/VectorBasedFusionAnalysis.cpp +++ b/lib/gc/Analysis/VectorBasedFusionAnalysis.cpp @@ -277,130 +277,6 @@ bool hasDataDependency(Operation *op1, Operation *op2) { return res; } -/// Get vector type of the operation \param op -/// \param isPrevOp whether the operation is a previous operation, if it is not -/// prev-op, may need to use result vectortype -/// default will return the opeation result type -mlir::FailureOr getOperationVectorType(Operation *op, - bool isPrevOp) { - if (not op) - return failure(); - - auto isDynamicType = [](VectorType &type) { return !type.hasStaticShape(); }; - auto ret = - TypeSwitch>(op) - .Case( - [&](vector::TransferWriteOp transferWriteOp) - -> mlir::FailureOr { - if (auto retType = dyn_cast( - transferWriteOp.getOperandTypes()[0])) - return retType; - - return failure(); - }) - .Case( - [&](vector::TransferReadOp transferReadOp) - -> mlir::FailureOr { - return transferReadOp.getVectorType(); - }) - .Case( - [&](vector::MultiDimReductionOp multiReductionOp) { - if (isPrevOp) - return cast( - multiReductionOp->getResultTypes()[0]); - - // TODO: may need to add accumulate value vectortype - return cast(multiReductionOp.getSourceVectorType()); - }) - .Default([&](Operation *op) -> mlir::FailureOr { - if (isPrevOp) { - if (op->getResultTypes().empty()) - return failure(); - - if (auto shapedType = - dyn_cast(op->getResultTypes()[0])) - return shapedType; - - return failure(); - } - if (op->getOperandTypes().empty()) - return failure(); - - if (auto shapedType = - dyn_cast(op->getOperandTypes()[0])) - return shapedType; - - return failure(); - }); - if (!failed(ret) and isDynamicType(ret.value())) - return failure(); - - return ret; -} - -/// get operation vector type -/// \param isPrevOp whether the operation is a previous operation, if it is not -/// prev-op, may need to use result vectortype -/// default will return the opeation result type -mlir::FailureOr getOperationMaxVectorType(Operation *op) { - if (not op) - return failure(); - - auto isDynamicType = [](VectorType &type) { return !type.hasStaticShape(); }; - auto ret = - TypeSwitch>(op) - .Case( - [&](vector::TransferWriteOp transferWriteOp) - -> mlir::FailureOr { - if (auto retType = - cast(transferWriteOp.getOperandTypes()[0])) - return retType; - return failure(); - }) - .Case( - [&](vector::TransferReadOp transferReadOp) - -> mlir::FailureOr { - return transferReadOp.getVectorType(); - }) - .Case( - [&](vector::MultiDimReductionOp multiReductionOp) { - return cast(multiReductionOp.getSourceVectorType()); - }) - .Default([&](Operation *op) -> mlir::FailureOr { - if (op->getResultTypes().empty() and op->getOperandTypes().empty()) - return failure(); - - if (op->getResultTypes().empty()) - return cast(op->getOperandTypes()[0]); - - if (op->getOperandTypes().empty()) - return cast(op->getResultTypes()[0]); - - auto 
opdType = cast(op->getOperandTypes()[0]); - auto retType = cast(op->getResultTypes()[0]); - return opdType.getRank() > retType.getRank() ? opdType : retType; - }); - if (!failed(ret) and isDynamicType(ret.value())) - return failure(); - - return ret; -} - -/// select nearest even step -int getNearestVectorStep(const int step) { - if (step <= 0) - llvm_unreachable("Wrong step."); - - int nbits = 0, n = step; - while (n) { - n = n >> 1; - nbits++; - } - if (nbits > 6 and (nbits != 7 or step != 64)) - llvm_unreachable("wrong nbits appear"); - return (1 << (nbits - 1)) == step ? step : (1 << nbits); -} - /// Get the operation which is not a read-write in current queue /// \param [in, out] op Operation *getNotReadWriteOperaiton(std::queue &tmpQ) { diff --git a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp index cda26dd6e..6d5d85746 100644 --- a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp +++ b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp @@ -85,18 +85,6 @@ static inline void moveOpBeginingOfBlock(Operation *op) { op->moveBefore(&block->front()); } -/// find the original tensor -Value findOriginalTensor(Value writeTensor, Block *block) { - while (auto wtOp = dyn_cast_or_null( - writeTensor.getDefiningOp())) { - if (block != writeTensor.getDefiningOp()->getBlock()) - break; - - writeTensor = wtOp->getOperand(1); - } - return writeTensor; -} - /// whether operation is a not support operation bool isNotSupportOperation(Operation *op) { return isa &accRelatedOps, } } -Value makeIndexArithConstantOp(OpBuilder &opBuilder, const Location &loc, - int64_t x) { - return opBuilder.create( - loc, opBuilder.getIndexType(), - opBuilder.getIntegerAttr(opBuilder.getIndexType(), x)); -} - void ForLoopGenerator::moveOperationsToCurrentForBody( const OpBuilder &b, std::queue &opQueue, GenerateLoopHelper &loopHelperParam) { @@ -2424,25 +2405,6 @@ void ForLoopGenerator::rewriteOperationAsVectorize( } } -mlir::FailureOr getOperationOperateTensor(Operation *op) { - return TypeSwitch>(op) - .Case( - [&](vector::TransferWriteOp transferWriteOp) { - // find original tensor.empty operation - auto writeTensor = transferWriteOp->getOperand(1); - writeTensor = - findOriginalTensor(writeTensor, transferWriteOp->getBlock()); - return writeTensor; - }) - .Case([&](vector::TransferReadOp transferReadOp) { - return transferReadOp->getOperand(0); - }) - .Default([&](Operation *op) { - LDBG("Try to get not DPS operation inits: " << *op << "\n"); - return failure(); - }); -} - void GroupOperationFusionImpl::removeOpInCurrentGroups(size_t grpIdx, Operation *op, Operation *replacedOp) { diff --git a/lib/gc/Transforms/Pipeline.cpp b/lib/gc/Transforms/Pipeline.cpp index 0c9bb6322..666dbb93f 100644 --- a/lib/gc/Transforms/Pipeline.cpp +++ b/lib/gc/Transforms/Pipeline.cpp @@ -150,6 +150,8 @@ void populateCPURuntimePasses(mlir::OpPassManager &pm) { } void populateLoweringToLLVMPasses(mlir::OpPassManager &pm) { + pm.addPass(createConvertVectorToSCFPass()); + pm.addPass(createConvertVectorToLLVMPass()); pm.addPass(createLowerAffinePass()); pm.addPass(createFinalizeMemRefToLLVMConversionPass()); pm.addPass(createConvertSCFToCFPass()); diff --git a/lib/gc/Transforms/TilingVector.hpp b/lib/gc/Transforms/TilingVector.hpp index 045a6c677..6e399a6dd 100644 --- a/lib/gc/Transforms/TilingVector.hpp +++ b/lib/gc/Transforms/TilingVector.hpp @@ -10,7 +10,6 @@ #include "gc/Analysis//VectorBasedFusionAnalysis.h" #include "gc/Analysis/TargetDescriptionAnalysis.h" -#include 
"gc/Transforms/Utils/VectorUtils.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -21,7 +20,6 @@ #include "mlir/Dialect/Tensor/Transforms/Transforms.h" #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" #include "mlir/Dialect/Vector/Transforms/Passes.h" -#include "mlir/ExecutionEngine/Float16bits.h" #include "mlir/IR/Visitors.h" #include "mlir/Transforms/CSE.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -30,23 +28,6 @@ namespace mlir { namespace gc { -//===----------------------------------------------------------------------===// -// helper function -//===----------------------------------------------------------------------===// - -/// build a constant operation of index type -Value makeIndexArithConstantOp(OpBuilder &opBuilder, Location &loc, int64_t x); - -/// get operation read or write tensor -mlir::FailureOr getOperationOperateTensor(Operation *op); - -/// set correct operand for the operation -void setOperationCorrectOperand( - Operation *op, ValueRange iterArgs, DenseMap &operandIdxMap, - DenseMap &originalOperandLoopArgsMap, - ArrayRef inductionVars, - DenseMap &opPermuationMap); - /// get fusion kind /// Has two kind: /// 1. OperationGroup: diff --git a/lib/gc/Transforms/Utils/VectorUtils.cpp b/lib/gc/Transforms/Utils/VectorUtils.cpp index ae169c878..949670303 100644 --- a/lib/gc/Transforms/Utils/VectorUtils.cpp +++ b/lib/gc/Transforms/Utils/VectorUtils.cpp @@ -160,5 +160,154 @@ std::variant numericLimitsMaximum(Type type) { } } +mlir::FailureOr getOperationVectorType(Operation *op, + bool isPrevOp) { + if (not op) + return failure(); + + auto isDynamicType = [](VectorType &type) { return !type.hasStaticShape(); }; + auto ret = + TypeSwitch>(op) + .Case( + [&](vector::TransferWriteOp transferWriteOp) + -> mlir::FailureOr { + if (auto retType = dyn_cast( + transferWriteOp.getOperandTypes()[0])) + return retType; + + return failure(); + }) + .Case( + [&](vector::TransferReadOp transferReadOp) + -> mlir::FailureOr { + return transferReadOp.getVectorType(); + }) + .Case( + [&](vector::MultiDimReductionOp multiReductionOp) { + if (isPrevOp) + return cast( + multiReductionOp->getResultTypes()[0]); + + // TODO: may need to add accumulate value vectortype + return cast(multiReductionOp.getSourceVectorType()); + }) + .Default([&](Operation *op) -> mlir::FailureOr { + if (isPrevOp) { + if (op->getResultTypes().empty()) + return failure(); + + if (auto shapedType = + dyn_cast(op->getResultTypes()[0])) + return shapedType; + + return failure(); + } + if (op->getOperandTypes().empty()) + return failure(); + + if (auto shapedType = + dyn_cast(op->getOperandTypes()[0])) + return shapedType; + + return failure(); + }); + if (!failed(ret) and isDynamicType(ret.value())) + return failure(); + + return ret; +} + +mlir::FailureOr getOperationMaxVectorType(Operation *op) { + if (not op) + return failure(); + + auto isDynamicType = [](VectorType &type) { return !type.hasStaticShape(); }; + auto ret = + TypeSwitch>(op) + .Case( + [&](vector::TransferWriteOp transferWriteOp) + -> mlir::FailureOr { + if (auto retType = + cast(transferWriteOp.getOperandTypes()[0])) + return retType; + return failure(); + }) + .Case( + [&](vector::TransferReadOp transferReadOp) + -> mlir::FailureOr { + return transferReadOp.getVectorType(); + }) + .Case( + [&](vector::MultiDimReductionOp multiReductionOp) { + return cast(multiReductionOp.getSourceVectorType()); + }) + .Default([&](Operation *op) -> 
mlir::FailureOr { + if (op->getResultTypes().empty() and op->getOperandTypes().empty()) + return failure(); + + if (op->getResultTypes().empty()) + return cast(op->getOperandTypes()[0]); + + if (op->getOperandTypes().empty()) + return cast(op->getResultTypes()[0]); + + auto opdType = cast(op->getOperandTypes()[0]); + auto retType = cast(op->getResultTypes()[0]); + return opdType.getRank() > retType.getRank() ? opdType : retType; + }); + if (!failed(ret) and isDynamicType(ret.value())) + return failure(); + + return ret; +} + +int getNearestVectorStep(const int step) { + if (step <= 0) + llvm_unreachable("Wrong step."); + + int nbits = 0, n = step; + while (n) { + n = n >> 1; + nbits++; + } + if (nbits > 6 and (nbits != 7 or step != 64)) + llvm_unreachable("wrong nbits appear"); + return (1 << (nbits - 1)) == step ? step : (1 << nbits); +} + +Value makeIndexArithConstantOp(OpBuilder &opBuilder, const Location &loc, + int64_t x) { + return opBuilder.create( + loc, opBuilder.getIndexType(), + opBuilder.getIntegerAttr(opBuilder.getIndexType(), x)); +} + +Value findOriginalTensor(Value writeTensor, Block *block) { + while (auto wtOp = dyn_cast_or_null( + writeTensor.getDefiningOp())) { + if (block != writeTensor.getDefiningOp()->getBlock()) + break; + + writeTensor = wtOp->getOperand(1); + } + return writeTensor; +} + +mlir::FailureOr getOperationOperateTensor(Operation *op) { + return TypeSwitch>(op) + .Case( + [&](vector::TransferWriteOp transferWriteOp) { + // find original tensor.empty operation + auto writeTensor = transferWriteOp->getOperand(1); + writeTensor = + findOriginalTensor(writeTensor, transferWriteOp->getBlock()); + return writeTensor; + }) + .Case([&](vector::TransferReadOp transferReadOp) { + return transferReadOp->getOperand(0); + }) + .Default([&](Operation *op) { return failure(); }); +} + } // namespace gc } // namespace mlir \ No newline at end of file From 1a2a9a1af12688b8ed55b5aa6a1243e457e91888 Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Fri, 13 Sep 2024 17:35:01 +0800 Subject: [PATCH 56/66] use linalgx utils function --- lib/gc/Analysis/VectorBasedFusionAnalysis.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/lib/gc/Analysis/VectorBasedFusionAnalysis.cpp b/lib/gc/Analysis/VectorBasedFusionAnalysis.cpp index 5cee5314f..b8a0924cc 100644 --- a/lib/gc/Analysis/VectorBasedFusionAnalysis.cpp +++ b/lib/gc/Analysis/VectorBasedFusionAnalysis.cpp @@ -30,8 +30,7 @@ namespace gc { tensor::InsertSliceOp, microkernel::BrgemmOp static inline bool isNotNeedToProcessOp(Operation *op) { - return isa(op) or - linalgx::isAnyGenericPackedMatmulOp(op); + return isa(op) or linalgx::isMatmulOp(op); } static inline bool isSpecialOp(Operation *op) { From f96c54404e82dea81d1dd3124729b9994911e33d Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Sat, 14 Sep 2024 22:19:45 +0800 Subject: [PATCH 57/66] temporaty save --- .../gc/Analysis/VectorBasedFusionAnalysis.h | 7 + include/gc/Transforms/Utils/VectorUtils.h | 39 +++ lib/gc/Analysis/VectorBasedFusionAnalysis.cpp | 3 + lib/gc/Transforms/CPUPhysicalRegisterPass.cpp | 279 ++++++------------ lib/gc/Transforms/Pipeline.cpp | 1 + lib/gc/Transforms/TilingVector.hpp | 4 +- lib/gc/Transforms/Utils/VectorUtils.cpp | 172 ++++++++++- 7 files changed, 318 insertions(+), 187 deletions(-) diff --git a/include/gc/Analysis/VectorBasedFusionAnalysis.h b/include/gc/Analysis/VectorBasedFusionAnalysis.h index 92c9abf22..c9987739b 100644 --- a/include/gc/Analysis/VectorBasedFusionAnalysis.h +++ 
b/include/gc/Analysis/VectorBasedFusionAnalysis.h @@ -123,6 +123,8 @@ class GroupOperationFusion : public VectorFusionBase { // store read and write operations permutation maps in order to convenient // to replace loop induction var DenseMap opPermuationMap; + /// record operation operand original operate value + DenseMap operandOriginalValue; public: GroupOperationFusion(func::FuncOp &func, HardWareInfo &info) @@ -154,6 +156,7 @@ class GroupOperationFusion : public VectorFusionBase { this->getGroupOpResults() = fusion.getGroupOpResults(); this->getGroupOpInitArgs() = fusion.getGroupOpInitArgs(); this->getOpPermuationMap() = fusion.getOpPermuationMap(); + this->getOperandOriginalValue() = fusion.getOperandOriginalValue(); this->getFunction() = fusion.getFunction(); this->getHardwareInfo() = fusion.getHardwareInfo(); this->getTypeHelper() = fusion.getTypeHelper(); @@ -196,6 +199,10 @@ class GroupOperationFusion : public VectorFusionBase { DenseMap &getOpPermuationMap() noexcept { return opPermuationMap; } + + DenseMap &getOperandOriginalValue() noexcept { + return operandOriginalValue; + } /// set operation groups void setGroupOpResults( const SmallVector< diff --git a/include/gc/Transforms/Utils/VectorUtils.h b/include/gc/Transforms/Utils/VectorUtils.h index a113d5f3a..6b2a2bd0f 100644 --- a/include/gc/Transforms/Utils/VectorUtils.h +++ b/include/gc/Transforms/Utils/VectorUtils.h @@ -8,16 +8,55 @@ #ifndef GC_TRANSFORMS_UTILS_VECTORUTILS_H #define GC_TRANSFORMS_UTILS_VECTORUTILS_H +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/TypeUtilities.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Debug.h" #include #include #include namespace mlir { namespace gc { +/// Need to move some operations like extract_slice or insert_slice. +/// Because those operation may interpret our analysis result. e.g.: +/// ``` +/// clang-format off +/// %21 = vector.transfer_read %18[%c0, %c0], %cst {in_bounds = [true, true]} : +/// tensor<16x16xf32>, vector<16x16xf32> %22 = arith.addf %21, %20 : +/// vector<16x16xf32> %23 = vector.transfer_write %22, %extracted_slice_12[%c0, +/// %c0] {in_bounds = [true, true]} : vector<16x16xf32>, tensor<16x16xf32> +/// %inserted_slice_13 = tensor.insert_slice %18 into %arg14[%arg13, 0] [16, 16] +/// [1, 1] : tensor<16x16xf32> into tensor<32x16xf32> %extracted_slice_14 = +/// tensor.extract_slice %arg16[%arg13, 0] [16, 16] [1, 1] : tensor<32x16xf32> +/// to tensor<16x16xf32> %24 = vector.transfer_read %cst_0[%c0, %c0], %cst +/// {in_bounds = [true, true]} : tensor<16x16xf32>, vector<16x16xf32> %25 = +/// arith.maximumf %22, %24 : vector<16x16xf32> %26 = vector.transfer_write %25, +/// %extracted_slice_14[%c0, %c0] {in_bounds = [true, true]} : +/// vector<16x16xf32>, tensor<16x16xf32> %inserted_slice_15 = +/// tensor.insert_slice %23 into %arg15[%arg13, 0] [16, 16] [1, 1] : +/// tensor<16x16xf32> into tensor<32x16xf32> %inserted_slice_16 = +/// tensor.insert_slice %26 into %arg16[%arg13, 0] [16, 16] [1, 1] : +/// tensor<16x16xf32> into tensor<32x16xf32> clang-format on +/// ``` +/// The maximumf and addf operation can be a same group, but the extract_slice +/// operation interpret us. +/// The move operation(extra_slice) will check its parameters. 
In order to +/// ensure that it does not affect the correctness of the result, we will only +/// move the moved op after the op to which the parameters belong to. If it's +/// operand is all the block argument, we will move it to the begining of the +/// block. +/// insert_slice just move them to the privious of the first operation which +/// use it. +void moveSomeInterferenceOperation( + func::FuncOp *func, MLIRContext *ctx, + std::function &conditionalFunc); + /// build a constant operation of index type Value makeIndexArithConstantOp(OpBuilder &opBuilder, const Location &loc, int64_t x); diff --git a/lib/gc/Analysis/VectorBasedFusionAnalysis.cpp b/lib/gc/Analysis/VectorBasedFusionAnalysis.cpp index b8a0924cc..1358080da 100644 --- a/lib/gc/Analysis/VectorBasedFusionAnalysis.cpp +++ b/lib/gc/Analysis/VectorBasedFusionAnalysis.cpp @@ -375,6 +375,9 @@ VectorType TypeHelper::getVectorzedType(Operation *op, uint32_t loopStep) { } int TypeHelper::generateValidSteps(int steps, VectorType type) { + // TODO: support odd shape using mask load store + if (type.getShape().back() & 1) + return 1; if (type.getShape().back() >= steps) return steps; int evenStep = getNearestVectorStep(type.getShape().back()); diff --git a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp index 6d5d85746..0095a78af 100644 --- a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp +++ b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp @@ -47,6 +47,10 @@ void printGroupOps(SmallVector, 8> &opGroups) { } } +static inline bool isBroadcastOp(Operation *op) { + return isa_and_nonnull(op); +} + static inline bool isProducerOp(Operation *op) { return isa(op); } @@ -75,16 +79,6 @@ static inline bool isSpecialOp(Operation *op) { op); } -static inline void moveOpBeginingOfBlock(Operation *op) { - Block *block = op->getBlock(); - if (block->getOperations().empty()) - llvm_unreachable("Emtpy block."); - - if (&block->front() == op) - return; - op->moveBefore(&block->front()); -} - /// whether operation is a not support operation bool isNotSupportOperation(Operation *op) { return isa(lastDim)) llvm_unreachable("Must be AffineDimExpr."); - SmallVector affineExprs; - affineExprs.push_back(lastDim); + SmallVector affineExprs(1, lastDim); auto destAffineMap = AffineMap::get(tensorType.getRank(), 0, affineExprs, rewriter.getContext()); - SmallVector inBounds(1, true); + SmallVector inBounds(1, true); if (isa(op)) { auto transferWriteOp = cast(op); transferWriteOp.setPermutationMap(destAffineMap); @@ -270,8 +263,7 @@ void setOpVectorizationPermutationMap(Operation *op, OpBuilder &rewriter, } // scf.for yield helper function -scf::YieldOp maybeYieldValue(OpBuilder &b, Location loc, - const ValueRange &value) { +scf::YieldOp maybeYieldValue(OpBuilder &b, Location loc, ValueRange value) { bool hasRetVal = !value.empty(); if (hasRetVal) return b.create(loc, value); @@ -1303,10 +1295,10 @@ void ForLoopGenerator::replaceOpUsersWithForLoopResult( for (auto x : nextAnchorResults) { auto originalResult = forResultOrignalResultMap[x]; Value forResult = forOp->getResults()[nextAnchorResultsIdxMap[x]]; - rewriter.replaceOpUsesWithIf(originalResult.getDefiningOp(), forResult, - replaceIfFn); // subsequent group must use the replaced result as operand rectifyGroupOperands(grpIdx, originalResult, forResult); + rewriter.replaceOpUsesWithIf(originalResult.getDefiningOp(), forResult, + replaceIfFn); } } scf::ForOp @@ -1943,10 +1935,6 @@ void LoopGeneratorImpl::initSpeicalOperationCanonicalizers() { cast(op)); 
getMultiRdCanonicalizers().back().prepareSpecialOperationInfo(); }) - .Case([&](vector::BroadcastOp broadCastOp) { - getBroadcastCanonicalizers().back().getCandidateOps().emplace_back( - cast(op)); - }) .Case([&](vector::TransposeOp tpOp) { getTransposeCanonicalizers().back().getCandidateOps().emplace_back( cast(op)); @@ -2428,13 +2416,13 @@ void GroupOperationFusionImpl::removeOpInCurrentGroups(size_t grpIdx, IRRewriter rewriter(op); rewriter.replaceOp(op, replacedOp); // update removed operation related operation anchor position - getGroupOperationFusion().getOpAnchorPos()[replacedOp] = - getOperationMaxVectorType(replacedOp)->getRank() - 1; - for (Operation *x : usesOp) - getGroupOperationFusion().getOpAnchorPos()[x] = - getOperationMaxVectorType(x)->getRank() - 1; + // getGroupOperationFusion().getOpAnchorPos()[replacedOp] = + // getOperationMaxVectorType(replacedOp)->getRank() - 1; + // for (Operation *x : usesOp) + // getGroupOperationFusion().getOpAnchorPos()[x] = + // getOperationMaxVectorType(x)->getRank() - 1; - updateOpGroupInfo(grpIdx); + // updateOpGroupInfo(grpIdx); } void GroupOperationFusionImpl::updateOpGroupInfo(size_t grpIdx) { @@ -2639,6 +2627,8 @@ void GroupOperationFusionImpl::makeSourceOpWriteResultToTensor( getGroupOperationFusion().getOpAnchorPos(); SmallVector, 8> &groupOpInitArgs = getGroupOperationFusion().getGroupOpInitArgs(); + SmallVector>, 8> + &groupOpResults = getGroupOperationFusion().getGroupOpResults(); if (!srcOpCanoniclizedMap.contains(sourceOp)) { // get write operation @@ -2647,12 +2637,18 @@ void GroupOperationFusionImpl::makeSourceOpWriteResultToTensor( .getNextTargetOperationInCurrentGroup( sourceOp, sourceOpGid)) { auto writeOpresult = writeOp->getResults()[0]; - auto writeTensor = writeOp->getOperands()[1]; + auto originalWriteTensor = writeOp->getOperands()[1]; // find original tensor.empty operation - writeTensor = findOriginalTensor(writeTensor, sourceOp->getBlock()); + Value writeTensor = + findOriginalTensor(originalWriteTensor, sourceOp->getBlock()); + if (writeTensor != originalWriteTensor) + getGroupOperationFusion() + .getOperandOriginalValue()[originalWriteTensor] = writeTensor; + srcOpCanoniclizedMap.insert({sourceOp, {writeTensor, writeOpresult}}); groupOpInitArgs[sourceOpGid].insert(writeTensor); - updateReturnResultKind(writeOp, sourceOpGid, rtKind); + groupOpResults[sourceOpGid].insert( + {writeOpresult, {rtKind, OpAnchorPos[sourceOp]}}); return; } generateEmptyTensorAndWrite(sourceOp, srcOpCanoniclizedMap, @@ -2683,9 +2679,9 @@ void GroupOperationFusionImpl::GroupOperationReturnResultProcess( if (failed(dstRet)) { // already generate result tensor, special operation do the // transformation by itself - if (isSpecialOp(sourceOp) and inSameGroupNeedReturn) { + if (isSpecialOp(sourceOp) and inSameGroupNeedReturn and + not isBroadcastOp(sourceOp)) return; - } makeSourceOpWriteResultToTensor(sourceOp, sourceOpGid, rtKind); auto opInit = canonicalizeCurrentOperation( op, srcOpCanoniclizedMap[sourceOp].second, operandIdx); @@ -2703,9 +2699,68 @@ void GroupOperationFusionImpl::GroupOperationReturnResultProcess( } // transfer write operation groupOpInitArgs[sourceOpGid].insert(dstRet.value()); + auto writeTensor = sourceOp->getOperand(1); + if (dstRet.value() != writeTensor) + getGroupOperationFusion().getOperandOriginalValue()[writeTensor] = + dstRet.value(); + updateReturnResultKind(sourceOp, sourceOpGid, rtKind); } +void GroupOperationFusionImpl::broadcastFromElements(Operation *op, + size_t grpIdx) { + if (not isa(op)) + 
llvm_unreachable("Must be broadcast operation."); + + if (not isa(op->getOperandTypes()[0])) { + auto inputBcastOp = cast(op); + size_t steps = getGroupOperationFusion().getGroupMaxSteps()[grpIdx]; + IRRewriter rewriter(op); + VectorType newOperandType = + getGroupOperationFusion().getTypeHelper().getVectorzedType(op, steps); + if (isa_and_nonnull(op->getOperand(0).getDefiningOp())) { + auto constantOp = cast(op); + SmallVector shapes(1, steps); + auto dataType = mlir::VectorType::get( + shapes, inputBcastOp.getResultVectorType().getElementType()); + + FailureOr res = createArithSplatConstantOp( + rewriter, op->getLoc(), + DenseElementsAttr::get(dataType, constantOp.getValue()), + newOperandType); + if (failed(res)) + llvm::llvm_unreachable_internal("Wrong to create constant op."); + removeOpInCurrentGroups(grpIdx, op, res.value().getDefiningOp()); + + } else { + auto bcastOp = rewriter.create( + op->getLoc(), newOperandType, op->getOperands()[0]); + removeOpInCurrentGroups(grpIdx, op, bcastOp); + std::function candidateFunc = isBroadcastOp; + moveSomeInterferenceOperation(&getGroupOperationFusion().getFunction(), + op->getContext(), candidateFunc); + } + } +} + +void GroupOperationFusionImpl::scalarOperandFromElements() { + auto &opGroups = getGroupOperationFusion().getOpGroups(); + + for (auto [idx, grp] : llvm::enumerate(opGroups)) { + + std::queue tmpQueue(grp); + while (!tmpQueue.empty()) { + auto op = tmpQueue.front(); + tmpQueue.pop(); + TypeSwitch(op) + .Case([&](vector::BroadcastOp &bcOp) { + broadcastFromElements(bcOp, idx); + }) + .Default([&](Operation *op) { return; }); + } + } +} + void GroupOperationFusionImpl::canonicalizeEachOperationGroup() { // record the operation which has been moved DenseSet movedOperationSet; @@ -2733,8 +2788,8 @@ void GroupOperationFusionImpl::canonicalizeEachOperationGroup() { bool outOfGroup = !opGroupIndexMap.contains(op); // Different anchor in same group and source operation is in inner // loop, we need to get source operation's result - bool inSameGroupNeedReturn = - !notInSameGroup and OpAnchorPos[sourceOp] > OpAnchorPos[op]; + bool inSameGroupNeedReturn = !outOfGroup and !notInSameGroup and + OpAnchorPos[sourceOp] > OpAnchorPos[op]; if (notInSameGroup or outOfGroup or inSameGroupNeedReturn) GroupOperationReturnResultProcess(sourceOpGid, sourceOp, op, idx, @@ -2747,6 +2802,7 @@ void GroupOperationFusionImpl::canonicalizeEachOperationGroup() { } }); analysisEmptyGroup(); + scalarOperandFromElements(); specialOperationRectify(visitedOperation); LDBG("Complete analysis group operation results\n"); } @@ -2756,6 +2812,10 @@ void ForLoopGenerator::rectifyGroupOperands(size_t currentGroupId, Value forResult) { size_t totalGroupSize = getVectorBasedFusion().getOpGroups().size(); size_t startGroup = currentGroupId; + DenseMap &operandOriginalMap = + getVectorBasedFusion().getOperandOriginalValue(); + if (operandOriginalMap.contains(originalResult)) + originalResult = operandOriginalMap[originalResult]; while (startGroup < totalGroupSize) { SetVector &operandVector = getVectorBasedFusion().getGroupOpInitArgs()[startGroup++]; @@ -2794,11 +2854,9 @@ mlir::FailureOr ForLoopGenerator::generateVectorizedForLoop( bool LoopGeneratorImpl::isGroupHasSpecialOperation(const size_t grpIdx) { auto &rdCanonicalizer = getMultiRdCanonicalizers()[grpIdx]; - auto &bcCanonicalizer = getBroadcastCanonicalizers()[grpIdx]; auto &tpCanonicalizer = getTransposeCanonicalizers()[grpIdx]; auto &shapeCastCanonicalizer = getShapeCastCanonicalizers()[grpIdx]; return 
!rdCanonicalizer.getCandidateOps().empty() or - !bcCanonicalizer.getCandidateOps().empty() or !tpCanonicalizer.getCandidateOps().empty() or !shapeCastCanonicalizer.getCandidateOps().empty(); } @@ -2810,9 +2868,8 @@ void LoopGeneratorImpl::generateGroupOpVectorizedIR(const int idx) { return; } // TODO: special operation better fusion - if (isGroupHasSpecialOperation(idx)) { + if (isGroupHasSpecialOperation(idx)) return; - } VectorType groupType = getVectorBasedFusion().getGroupBiggestRankVectorType()[idx]; @@ -2829,152 +2886,6 @@ void LoopGeneratorImpl::generateGroupOpVectorizedIR(const int idx) { moveLoopInvariantCode(forOp.value()); } -LogicalResult -moveFront(Operation *op, - llvm::DenseMap &operationPosition) { - IRRewriter rewriter(op); - Operation *backOperation; - size_t pos = 0; - // check all the operand is block argument - bool allBlockArgs = true; - for (auto operand : op->getOperands()) { - if (!isa(operand)) { - allBlockArgs = false; - break; - } - } - if (allBlockArgs) { - moveOpBeginingOfBlock(op); - return success(); - } - for (auto operand : op->getOperands()) { - if (isa(operand)) - continue; - - Operation *sourceOp = operand.getDefiningOp(); - if (operationPosition[sourceOp] > pos and - sourceOp->getBlock() == op->getBlock()) { - backOperation = sourceOp; - pos = operationPosition[sourceOp]; - } - } - if (pos == 0) { - // extract operand operation all in previous block - moveOpBeginingOfBlock(op); - return success(); - } - if (backOperation) { - rewriter.moveOpAfter(op, backOperation); - return success(); - } - return failure(); -} - -LogicalResult moveBack(Operation *op, - llvm::DenseMap &operationPosition) { - IRRewriter rewriter(op); - Operation *firstOperation; - size_t pos = std::numeric_limits::max(); - for (auto user : op->getUsers()) { - if (operationPosition[user] < pos and user->getBlock() == op->getBlock()) { - firstOperation = user; - pos = operationPosition[user]; - } - } - if (pos == std::numeric_limits::max()) { - // Don't move. - // TODO: need to consider move before the block which use it. - return success(); - } - if (firstOperation) { - rewriter.moveOpBefore(op, firstOperation); - return success(); - } - return failure(); -} - -void moveCandidateOperation( - llvm::DenseMap &operationPosition, - ArrayRef candidateOps) { - - for (Operation *op : candidateOps) { - auto ret = - TypeSwitch(op) - .Case([&](affine::AffineApplyOp affineOp) { - return moveFront(op, operationPosition); - }) - .Case( - [&](tensor::ExtractSliceOp extractOp) { - return moveFront(op, operationPosition); - }) - .Case([&](tensor::EmptyOp emptyOp) { - return moveFront(op, operationPosition); - }) - .Case([&](tensor::InsertSliceOp insertOp) { - return moveBack(op, operationPosition); - }) - .Case([&](vector::TransferReadOp readOp) { - return moveFront(op, operationPosition); - }) - .Case( - [&](vector::TransferWriteOp writeOp) { - return moveBack(op, operationPosition); - }) - .Default([&](Operation *op) { return success(); }); - if (failed(ret)) { - LDBG("Wrong to move operation:" << *op << "\n"); - return; - } - } -} - -// Need to move some operations like extract_slice or insert_slice. -// Because those operation may interpret our analysis result. 
e.g.: -// ``` -// clang-format off - // %21 = vector.transfer_read %18[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<16x16xf32>, vector<16x16xf32> - // %22 = arith.addf %21, %20 : vector<16x16xf32> - // %23 = vector.transfer_write %22, %extracted_slice_12[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf32>, tensor<16x16xf32> - // %inserted_slice_13 = tensor.insert_slice %18 into %arg14[%arg13, 0] [16, 16] [1, 1] : tensor<16x16xf32> into tensor<32x16xf32> - // %extracted_slice_14 = tensor.extract_slice %arg16[%arg13, 0] [16, 16] [1, 1] : tensor<32x16xf32> to tensor<16x16xf32> - // %24 = vector.transfer_read %cst_0[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<16x16xf32>, vector<16x16xf32> - // %25 = arith.maximumf %22, %24 : vector<16x16xf32> - // %26 = vector.transfer_write %25, %extracted_slice_14[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf32>, tensor<16x16xf32> - // %inserted_slice_15 = tensor.insert_slice %23 into %arg15[%arg13, 0] [16, 16] [1, 1] : tensor<16x16xf32> into tensor<32x16xf32> - // %inserted_slice_16 = tensor.insert_slice %26 into %arg16[%arg13, 0] [16, 16] [1, 1] : tensor<16x16xf32> into tensor<32x16xf32> -// clang-format on -// ``` -// The maximumf and addf operation can be a same group, but the extract_slice -// operation interpret us. -// The move operation(extra_slice) will check its parameters. In order to -// ensure that it does not affect the correctness of the result, we will only -// move the moved op after the op to which the parameters belong to. If it's -// operand is all the block argument, we will move it to the begining of the -// block. -// insert_slice just move them to the privious of the first operation which -// use it. -void moveSomeInterferenceOperation( - func::FuncOp *func, MLIRContext *ctx, - std::function &conditionalFunc) { - // Pre-order traversal of each op - // Record each operation position. Inorder to we can kown current operation - // should move after which operation. - DenseMap operationPosition; - SmallVector candidateOps; - size_t opCounter = 0; - - // get the position of each operation - func->walk([&](Operation *op) { - operationPosition[op] = opCounter++; - if (conditionalFunc(op)) - candidateOps.emplace_back(op); - }); - moveCandidateOperation(operationPosition, candidateOps); - // eliminate some useless operation - RewritePatternSet patterns(ctx); - (void)applyPatternsAndFoldGreedily(*func, std::move(patterns)); -} - /// Pass that lower to physical vector. 
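The helpers removed above (moveFront, moveBack, moveCandidateOperation, moveSomeInterferenceOperation) are re-added under lib/gc/Transforms/Utils/VectorUtils.cpp later in this patch, and the pass keeps reaching them through the declaration now exposed in Utils/VectorUtils.h. A minimal sketch of such a call site, assuming the caller already holds the func::FuncOp; the predicate shown here is an illustrative stand-in, not code from this patch:

```
// Hedged sketch: reposition slice/transfer ops before group analysis runs.
// `func` is the function being vectorized; the predicate is an assumption.
std::function<bool(Operation *)> needsMove = [](Operation *op) {
  return isa<tensor::ExtractSliceOp, tensor::InsertSliceOp,
             vector::TransferReadOp, vector::TransferWriteOp>(op);
};
moveSomeInterferenceOperation(&func, func.getContext(), needsMove);
```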
struct CPUPhysicalRegisterPass : public impl::CPUPhysicalRegisterPassBase { diff --git a/lib/gc/Transforms/Pipeline.cpp b/lib/gc/Transforms/Pipeline.cpp index 666dbb93f..c82436aae 100644 --- a/lib/gc/Transforms/Pipeline.cpp +++ b/lib/gc/Transforms/Pipeline.cpp @@ -90,6 +90,7 @@ void populateVectorPasses(mlir::OpPassManager &pm) { // Bf16 cast elimilation pass pm.addNestedPass(mlir::createCanonicalizerPass()); pm.addNestedPass(createCPUPhysicalRegisterPass()); + pm.addPass(createLoopInvariantCodeMotionPass()); // todo: lower to physical vector pass, device dependent pass populateCleanUpPasses(pm); } diff --git a/lib/gc/Transforms/TilingVector.hpp b/lib/gc/Transforms/TilingVector.hpp index 6e399a6dd..dc54da35e 100644 --- a/lib/gc/Transforms/TilingVector.hpp +++ b/lib/gc/Transforms/TilingVector.hpp @@ -299,7 +299,6 @@ class TransposeCanonicalizer class ShapeCastCanonicalizer : public SpecialOperationCanonicalizer { -private: public: ShapeCastCanonicalizer( const SmallVector &candidateScOps, @@ -546,6 +545,9 @@ class GroupOperationFusionImpl : public GroupOperationAnalysis { GroupOperationFusionImpl(func::FuncOp &func, HardWareInfo &info) : GroupOperationAnalysis(func, info) {} + void broadcastFromElements(Operation *op, size_t grpIdx); + void scalarOperandFromElements(); + /// Generate emtpy tensor and write operations for operations that need to /// return their results, and generate read operations for operations that /// need to read parameters from the block. diff --git a/lib/gc/Transforms/Utils/VectorUtils.cpp b/lib/gc/Transforms/Utils/VectorUtils.cpp index 949670303..e76f78765 100644 --- a/lib/gc/Transforms/Utils/VectorUtils.cpp +++ b/lib/gc/Transforms/Utils/VectorUtils.cpp @@ -6,10 +6,176 @@ // //===----------------------------------------------------------------------===// #include "gc/Transforms/Utils/VectorUtils.h" +#include "mlir/Support/LLVM.h" namespace mlir { namespace gc { +#define DEBUG_TYPE "vector-utils" + +#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") +#define SAFE_EXPAND(X) X +#define LDBG(X) LLVM_DEBUG(DBGS() << SAFE_EXPAND(X) << "\n") + +static inline void moveOpBeginingOfBlock(Operation *op) { + Block *block = op->getBlock(); + if (block->getOperations().empty()) + llvm_unreachable("Emtpy block."); + + if (&block->front() == op) + return; + op->moveBefore(&block->front()); +} + +LogicalResult +moveFront(Operation *op, + llvm::DenseMap &operationPosition) { + IRRewriter rewriter(op); + Operation *backOperation; + size_t pos = 0; + // check all the operand is block argument + bool allBlockArgs = true; + for (auto operand : op->getOperands()) { + if (!isa(operand)) { + allBlockArgs = false; + break; + } + } + if (allBlockArgs) { + moveOpBeginingOfBlock(op); + return success(); + } + for (auto operand : op->getOperands()) { + if (isa(operand)) + continue; + + Operation *sourceOp = operand.getDefiningOp(); + if (operationPosition[sourceOp] > pos and + sourceOp->getBlock() == op->getBlock()) { + backOperation = sourceOp; + pos = operationPosition[sourceOp]; + } + } + if (pos == 0) { + // extract operand operation all in previous block + moveOpBeginingOfBlock(op); + return success(); + } + if (backOperation) { + rewriter.moveOpAfter(op, backOperation); + return success(); + } + return failure(); +} + +LogicalResult moveBack(Operation *op, + llvm::DenseMap &operationPosition) { + IRRewriter rewriter(op); + Operation *firstOperation; + size_t pos = std::numeric_limits::max(); + for (auto user : op->getUsers()) { + if (operationPosition[user] < pos and 
user->getBlock() == op->getBlock()) { + firstOperation = user; + pos = operationPosition[user]; + } + } + if (pos == std::numeric_limits::max()) { + // Don't move. + // TODO: need to consider move before the block which use it. + return success(); + } + if (firstOperation) { + rewriter.moveOpBefore(op, firstOperation); + return success(); + } + return failure(); +} + +void moveCandidateOperation( + llvm::DenseMap &operationPosition, + ArrayRef candidateOps) { + + for (Operation *op : candidateOps) { + auto ret = + TypeSwitch(op) + .Case([&](affine::AffineApplyOp affineOp) { + return moveFront(op, operationPosition); + }) + .Case( + [&](tensor::ExtractSliceOp extractOp) { + return moveFront(op, operationPosition); + }) + .Case([&](tensor::EmptyOp emptyOp) { + return moveFront(op, operationPosition); + }) + .Case([&](tensor::InsertSliceOp insertOp) { + return moveBack(op, operationPosition); + }) + .Case([&](vector::TransferReadOp readOp) { + return moveFront(op, operationPosition); + }) + .Case( + [&](vector::TransferWriteOp writeOp) { + return moveBack(op, operationPosition); + }) + .Case([&](vector::BroadcastOp bcOp) { + return moveFront(op, operationPosition); + }) + .Default([&](Operation *op) { return success(); }); + if (failed(ret)) { + LDBG("Wrong to move operation:" << *op << "\n"); + return; + } + } +} + +// Need to move some operations like extract_slice or insert_slice. +// Because those operation may interpret our analysis result. e.g.: +// ``` +// clang-format off + // %21 = vector.transfer_read %18[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<16x16xf32>, vector<16x16xf32> + // %22 = arith.addf %21, %20 : vector<16x16xf32> + // %23 = vector.transfer_write %22, %extracted_slice_12[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf32>, tensor<16x16xf32> + // %inserted_slice_13 = tensor.insert_slice %18 into %arg14[%arg13, 0] [16, 16] [1, 1] : tensor<16x16xf32> into tensor<32x16xf32> + // %extracted_slice_14 = tensor.extract_slice %arg16[%arg13, 0] [16, 16] [1, 1] : tensor<32x16xf32> to tensor<16x16xf32> + // %24 = vector.transfer_read %cst_0[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<16x16xf32>, vector<16x16xf32> + // %25 = arith.maximumf %22, %24 : vector<16x16xf32> + // %26 = vector.transfer_write %25, %extracted_slice_14[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf32>, tensor<16x16xf32> + // %inserted_slice_15 = tensor.insert_slice %23 into %arg15[%arg13, 0] [16, 16] [1, 1] : tensor<16x16xf32> into tensor<32x16xf32> + // %inserted_slice_16 = tensor.insert_slice %26 into %arg16[%arg13, 0] [16, 16] [1, 1] : tensor<16x16xf32> into tensor<32x16xf32> +// clang-format on +// ``` +// The maximumf and addf operation can be a same group, but the extract_slice +// operation interpret us. +// The move operation(extra_slice) will check its parameters. In order to +// ensure that it does not affect the correctness of the result, we will only +// move the moved op after the op to which the parameters belong to. If it's +// operand is all the block argument, we will move it to the begining of the +// block. +// insert_slice just move them to the privious of the first operation which +// use it. +void moveSomeInterferenceOperation( + func::FuncOp *func, MLIRContext *ctx, + std::function &conditionalFunc) { + // Pre-order traversal of each op + // Record each operation position. Inorder to we can kown current operation + // should move after which operation. 
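  // With pre-order numbering, a candidate's same-block producers carry
  // smaller indices and its users carry larger ones; moveFront() re-anchors
  // an op right after its latest same-block producer, while moveBack()
  // places it right before its earliest same-block user.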
+ DenseMap operationPosition; + SmallVector candidateOps; + size_t opCounter = 0; + + // get the position of each operation + func->walk([&](Operation *op) { + operationPosition[op] = opCounter++; + if (conditionalFunc(op)) + candidateOps.emplace_back(op); + }); + moveCandidateOperation(operationPosition, candidateOps); + // eliminate some useless operation + RewritePatternSet patterns(ctx); + (void)applyPatternsAndFoldGreedily(*func, std::move(patterns)); +} + const uint32_t kF32MantiBits = 23; const uint32_t kF32HalfMantiBitDiff = 13; const uint32_t kF32HalfBitDiff = 16; @@ -245,10 +411,12 @@ mlir::FailureOr getOperationMaxVectorType(Operation *op) { if (op->getResultTypes().empty() and op->getOperandTypes().empty()) return failure(); - if (op->getResultTypes().empty()) + if (op->getResultTypes().empty() or + not isa(op->getResultTypes()[0])) return cast(op->getOperandTypes()[0]); - if (op->getOperandTypes().empty()) + if (op->getOperandTypes().empty() or + not isa(op->getOperandTypes()[0])) return cast(op->getResultTypes()[0]); auto opdType = cast(op->getOperandTypes()[0]); From bec59abd3f5f891381cca50b4d488ee0a9d92b42 Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Wed, 18 Sep 2024 16:10:38 +0800 Subject: [PATCH 58/66] rename file name --- lib/gc/Transforms/CPUPhysicalRegisterPass.cpp | 9 +++++---- ...etor-distribution.mlir => cpu-phyaical-register.mlir} | 0 2 files changed, 5 insertions(+), 4 deletions(-) rename test/mlir/test/gc/Transforms/{cpu-vetor-distribution.mlir => cpu-phyaical-register.mlir} (100%) diff --git a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp index 0095a78af..db41e8127 100644 --- a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp +++ b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp @@ -2744,10 +2744,10 @@ void GroupOperationFusionImpl::broadcastFromElements(Operation *op, } void GroupOperationFusionImpl::scalarOperandFromElements() { - auto &opGroups = getGroupOperationFusion().getOpGroups(); - - for (auto [idx, grp] : llvm::enumerate(opGroups)) { - + SmallVector, 8> &opGroups = + getGroupOperationFusion().getOpGroups(); + size_t idx = 0; + for (auto grp : opGroups) { std::queue tmpQueue(grp); while (!tmpQueue.empty()) { auto op = tmpQueue.front(); @@ -2758,6 +2758,7 @@ void GroupOperationFusionImpl::scalarOperandFromElements() { }) .Default([&](Operation *op) { return; }); } + idx++; } } diff --git a/test/mlir/test/gc/Transforms/cpu-vetor-distribution.mlir b/test/mlir/test/gc/Transforms/cpu-phyaical-register.mlir similarity index 100% rename from test/mlir/test/gc/Transforms/cpu-vetor-distribution.mlir rename to test/mlir/test/gc/Transforms/cpu-phyaical-register.mlir From 4c901c4672d28393c884fdd5165cfa4182f9a70d Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Wed, 18 Sep 2024 16:15:06 +0800 Subject: [PATCH 59/66] push local change --- lib/gc/Transforms/CPUPhysicalRegisterPass.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp index db41e8127..ceb1d6e97 100644 --- a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp +++ b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp @@ -2747,7 +2747,7 @@ void GroupOperationFusionImpl::scalarOperandFromElements() { SmallVector, 8> &opGroups = getGroupOperationFusion().getOpGroups(); size_t idx = 0; - for (auto grp : opGroups) { + for (auto &grp : opGroups) { std::queue tmpQueue(grp); while (!tmpQueue.empty()) { auto op = tmpQueue.front(); From 
9b6e6c88be52a51eb64b4388c157eacf2ce72347 Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Thu, 19 Sep 2024 17:38:56 +0800 Subject: [PATCH 60/66] fix reduce loop indice --- .../gc/Analysis/VectorBasedFusionAnalysis.h | 6 +- lib/gc/Analysis/VectorBasedFusionAnalysis.cpp | 14 +++- lib/gc/Transforms/CPUPhysicalRegisterPass.cpp | 83 ++++++++++++++----- lib/gc/Transforms/TilingVector.hpp | 5 +- .../gc/Transforms/cpu-phyaical-register.mlir | 4 +- 5 files changed, 85 insertions(+), 27 deletions(-) diff --git a/include/gc/Analysis/VectorBasedFusionAnalysis.h b/include/gc/Analysis/VectorBasedFusionAnalysis.h index c9987739b..4034effbd 100644 --- a/include/gc/Analysis/VectorBasedFusionAnalysis.h +++ b/include/gc/Analysis/VectorBasedFusionAnalysis.h @@ -48,7 +48,11 @@ class TypeHelper { int getDataTypeValidSteps(VectorType type); /// get vector \param type an even for loop step int generateValidSteps(int steps, VectorType type); - /// get vector \param type max simd length according to hardware information + /// get vector \param type an even for loop step when shape dimension is + /// shapeDim + int generateValidSteps(int steps, VectorType type, int shapeDim); + /// get vector \param type max simd length according to hardware + /// information int getDataTypeMAXSIMDLength(VectorType type); /// get operation's vector type VectorType getVectorzedType(Operation *op, uint32_t loopStep = 0); diff --git a/lib/gc/Analysis/VectorBasedFusionAnalysis.cpp b/lib/gc/Analysis/VectorBasedFusionAnalysis.cpp index 1358080da..5bd67b367 100644 --- a/lib/gc/Analysis/VectorBasedFusionAnalysis.cpp +++ b/lib/gc/Analysis/VectorBasedFusionAnalysis.cpp @@ -374,14 +374,24 @@ VectorType TypeHelper::getVectorzedType(Operation *op, uint32_t loopStep) { return VectorType::get({loopStep}, vectorizedType.getElementType()); } +int TypeHelper::generateValidSteps(int steps, VectorType type, int shapeDim) { + if (shapeDim & 1) + return 1; + auto typebits = type.getElementTypeBitWidth(); + if (shapeDim >= steps) + return steps * typebits >= 128 ? steps : 1; + int evenStep = getNearestVectorStep(shapeDim); + return evenStep * typebits >= 128 ? evenStep : 1; +} + int TypeHelper::generateValidSteps(int steps, VectorType type) { // TODO: support odd shape using mask load store if (type.getShape().back() & 1) return 1; + auto typebits = type.getElementTypeBitWidth(); if (type.getShape().back() >= steps) - return steps; + return steps * typebits >= 128 ? steps : 1; int evenStep = getNearestVectorStep(type.getShape().back()); - auto typebits = type.getElementTypeBitWidth(); return evenStep * typebits >= 128 ? 
evenStep : 1; } diff --git a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp index ceb1d6e97..d8a5512ab 100644 --- a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp +++ b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp @@ -545,12 +545,6 @@ void updateCurrentArgsStatus(ValueRange loopState, const size_t loopStateIdx, DenseMap &nextOriginalOperandMap, DenseMap &nextOperandOriginalMap) { Value currentArgs = loopState[loopStateIdx]; - if (currentArgs.getType() != originalValue.getType()) { - llvm::outs() << loopStateIdx << "," - << "\n"; - currentArgs.dump(); - llvm::llvm_unreachable_internal("Type not equal."); - } nextAnchorArgs.emplace_back(currentArgs); nextAnchorArgsIdxMap[currentArgs] = nextAnchorArgs.size() - 1; nextOriginalOperandMap[originalValue] = currentArgs; @@ -740,6 +734,36 @@ void updateLoopArgsData(Value val, Value originalVal, originalOperandLoopArgsMap[originalVal] = val; } +void LoopGeneratorImpl::rectifyParallelIndice( + GenerateLoopHelper &loopHelperParam, OpBuilder &b, Location loc) { + MultiReductionCanonicalizer rdCanonicalizer = + getMultiRdCanonicalizers()[loopHelperParam.groupIdx]; + auto &multireductionOp = rdCanonicalizer.getCandidateOps()[0]; + SmallVector &reductionAxis = rdCanonicalizer.getReductionAxis(); + + // rectify indice of read from source operand + auto sourceReadOp = + multireductionOp.getSource().getDefiningOp(); + if (!sourceReadOp) + return; + + AffineExpr outterParallel, innerParallel; + bindDims(multireductionOp->getContext(), outterParallel, innerParallel); + + Value op = + loopHelperParam.inductionVars[loopHelperParam.inductionVars.size() - + reductionAxis.size() - 2]; + Value ip = + loopHelperParam.inductionVars[loopHelperParam.inductionVars.size() - + reductionAxis.size() - 1]; + Value newIndice = b.createOrFold( + loc, (outterParallel + innerParallel), ValueRange{op, ip}); + int parallelSize = rdCanonicalizer.getParallelAxis().size(); + int readIndiceOffset = + 1 + rdCanonicalizer.getParallelAxis()[parallelSize - 1]; + sourceReadOp->setOperand(readIndiceOffset, newIndice); +} + scf::ForOp LoopGeneratorImpl::reductionAxisGenerateForLoop( OpBuilder &opBuilder, const size_t reductionIdx, GenerateLoopHelper &loopHelperParam) { @@ -755,18 +779,22 @@ scf::ForOp LoopGeneratorImpl::reductionAxisGenerateForLoop( const auto loc = multireductionOp->getLoc(); SmallVector &reductionAxis = rdCanonicalizer.getReductionAxis(); - bool lastDimReduction = rdCanonicalizer.hasLastDimReduction(); VectorType vectorType = rdCanonicalizer.getSourceType(); - const int loopStep = - getVectorBasedFusion().getGroupMaxSteps()[loopHelperParam.groupIdx]; + auto tpHelper = fusionStrategy.getTypeHelper(); + + int loopStep = tpHelper.generateValidSteps( + fusionStrategy.getTypeHelper().getDataTypeMAXSIMDLength(vectorType), + vectorType, vectorType.getShape()[reductionAxis[reductionIdx]]); + bool isLastDimReduction = rdCanonicalizer.getHasLastDimReduction(); + loopStep = (reductionIdx == reductionAxis.size() - 1 && isLastDimReduction) + ? loopStep + : 1; + func::FuncOp func = fusionStrategy.getFunction(); IRRewriter rewriterOfFunc(func); Value zero = makeIndexArithConstantOp(opBuilder, loc, 0); - Value forSteps = makeIndexArithConstantOp( - opBuilder, loc, - (reductionIdx == reductionAxis.size() - 1 && lastDimReduction) ? 
loopStep - : 1); + Value forSteps = makeIndexArithConstantOp(opBuilder, loc, loopStep); Value numIter = makeIndexArithConstantOp( opBuilder, loc, vectorType.getShape()[reductionAxis[reductionIdx]]); scf::ForOp forOp = opBuilder.create( @@ -868,9 +896,12 @@ scf::ForOp LoopGeneratorImpl::reductionAxisGenerateForLoop( } rewriteOperationAsVectorize(b, loopHelperParam.groupIdx, - &movingOperation); + &movingOperation, + isLastDimReduction ? loopStep : 0); loopHelperParam.loopIterArgs = loopState; moveOperationsToCurrentForBody(b, movingOperation, loopHelperParam); + if (isLastDimReduction) + rectifyParallelIndice(loopHelperParam, b, loc); loopHelperParam.movedOps = &movingOperation; loopHelperParam.candidateOps = &opQueue; @@ -1058,11 +1089,16 @@ scf::ForOp LoopGeneratorImpl::parallelAxisGenerateForLoop( // get accumualte value Attribute initValueAttr; getReductionInitAttr(multiReductionOp, initValueAttr); - + SmallVector &reductionAxis = + rdCanonicalizer.getReductionAxis(); + TypeHelper tpHelper = fusionStrategy.getTypeHelper(); + int loopStep = tpHelper.generateValidSteps( + tpHelper.getDataTypeMAXSIMDLength(vectorType), vectorType, + vectorType.getShape()[reductionAxis[reductionAxis.size() - 1]]); auto accVal = b.create( loc, DenseElementsAttr::get( fusionStrategy.getTypeHelper().getVectorzedType( - multiReductionOp, dimSize), + multiReductionOp, loopStep), {initValueAttr})); // put accumulte val at first for loop args @@ -1247,14 +1283,14 @@ void LoopGeneratorImpl::rearrageMultiReductionIR( DenseMap varLoopIdxMap; VectorType groupVector = getVectorBasedFusion().getGroupBiggestRankVectorType()[grpIdx]; - for (size_t i = 0; i < parallelAxis.size(); i++) { + for (size_t i = 0; i < parallelAxis.size(); i++) varLoopIdxMap[parallelAxis[i]] = i; - } + size_t offset = rdCanonicalizer.hasLastDimReduction() ? 1 : 0; for (size_t i = parallelAxis.size() + offset; - i < groupVector.getRank() + offset; i++) { + i < groupVector.getRank() + offset; i++) varLoopIdxMap[reductionAxis[i - parallelAxis.size() - offset]] = i; - } + while (!tmpSourceQ.empty()) { auto *curOp = tmpSourceQ.front(); tmpSourceQ.pop(); @@ -2313,7 +2349,8 @@ void ForLoopGenerator::createNewConstantOp( /// Rewrite the operations in the group to vectorized form. void ForLoopGenerator::rewriteOperationAsVectorize( - OpBuilder &rewriter, size_t groupId, const std::queue *queue) { + OpBuilder &rewriter, size_t groupId, const std::queue *queue, + const size_t vectorizeStep) { const std::queue groupOps = !queue ? getVectorBasedFusion().getOpGroups()[groupId] : *queue; @@ -2322,7 +2359,9 @@ void ForLoopGenerator::rewriteOperationAsVectorize( DenseMap &opPermuationMap = getVectorBasedFusion().getOpPermuationMap(); std::queue transformQueue(groupOps); - size_t groupSteps = getVectorBasedFusion().getGroupMaxSteps()[groupId]; + size_t groupSteps = vectorizeStep == 0 + ? 
getVectorBasedFusion().getGroupMaxSteps()[groupId] + : vectorizeStep; while (!transformQueue.empty()) { Operation *op = transformQueue.front(); diff --git a/lib/gc/Transforms/TilingVector.hpp b/lib/gc/Transforms/TilingVector.hpp index dc54da35e..45303a29d 100644 --- a/lib/gc/Transforms/TilingVector.hpp +++ b/lib/gc/Transforms/TilingVector.hpp @@ -354,7 +354,8 @@ class ForLoopGenerator { /// rewrite operation as vectorize IR in current operation group void rewriteOperationAsVectorize(OpBuilder &rewriter, size_t groupId, - const std::queue *queue = nullptr); + const std::queue *queue = nullptr, + const size_t vectorizeStep = 0); /// Reimplementation of writing a tensor from a constant of denseElementattr. void createNewConstantOp(Operation *srcOp, vector::TransferWriteOp *transferWriteOp, @@ -489,6 +490,8 @@ class LoopGeneratorImpl : public ForLoopGenerator { scf::ForOp reductionAxisGenerateForLoop(OpBuilder &opBuilder, const size_t reductionIdx, GenerateLoopHelper &loopHelperParam); + void rectifyParallelIndice(GenerateLoopHelper &loopHelperParam, OpBuilder &b, + Location loc); /// reduction operation parallel axis for loop scf::ForOp parallelAxisGenerateForLoop(OpBuilder &opBuilder, GenerateLoopHelper &loopHelperParam); diff --git a/test/mlir/test/gc/Transforms/cpu-phyaical-register.mlir b/test/mlir/test/gc/Transforms/cpu-phyaical-register.mlir index fd380ad24..80eb35112 100644 --- a/test/mlir/test/gc/Transforms/cpu-phyaical-register.mlir +++ b/test/mlir/test/gc/Transforms/cpu-phyaical-register.mlir @@ -10,6 +10,7 @@ // CHECK-DAG: #[[map6:.*]] = affine_map<(d0, d1) -> (d0 floordiv 16 + d1 floordiv 16)> // CHECK-DAG: #[[map7:.*]] = affine_map<()[s0, s1] -> (s0 * 32 + s1)> // CHECK-DAG: #[[map8:.*]] = affine_map<()[s0, s1] -> (s0 * 16 + s1)> +// CHECK-DAG: #[[map9:.*]] = affine_map<(d0, d1) -> (d0 + d1)> @@ -619,7 +620,8 @@ func.func @reduce_fuse_test12(%input: tensor<16x32x64xf32>, // CHECK: scf.for %[[arg4:.*]] = %[[C0]] to %[[C16]] step %[[C1]] iter_args(%[[arg5:.*]] = %[[READ1]]) -> (vector<16xf32>) // CHECK: scf.for %[[arg6:.*]] = %[[C0]] to %[[C32]] step %[[C1]] iter_args(%[[arg7:.*]] = %[[CST]]) -> (vector<16xf32>) // CHECK: scf.for %[[arg8:.*]] = %[[C0]] to %[[C64]] step %[[C16]] iter_args(%[[arg9:.*]] = %[[arg7]]) -> (vector<16xf32>) -// CHECK: %[[READ2:.*]] = vector.transfer_read {{.*}}[%[[arg2]], %[[arg6]], %[[arg8]]], {{.*}} {in_bounds = [true]} : tensor<16x32x64xf32>, vector<16xf32> +// CHECK: %[[APPLY0:.*]] = affine.apply #[[map9]](%[[arg2]], %[[arg4]]) +// CHECK: %[[READ2:.*]] = vector.transfer_read {{.*}}[%[[APPLY0]], %[[arg6]], %[[arg8]]], {{.*}} {in_bounds = [true]} : tensor<16x32x64xf32>, vector<16xf32> // CHECK: %[[ADD0:.*]] = arith.addf %[[READ2]], %[[arg9]] : vector<16xf32> // CHECK: %[[REDUCTION:.*]] = vector.reduction , {{.*}} : vector<16xf32> into f32 // CHECK: %[[INSERT:.*]] = vector.insert %[[REDUCTION]], %[[arg5]] [%[[arg4]]] : f32 into vector<16xf32> From aac20a0ad0af253dc09c1b9ad13b1e7a71d4c1e8 Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Fri, 20 Sep 2024 13:49:14 +0800 Subject: [PATCH 61/66] simplify code --- lib/gc/Analysis/VectorBasedFusionAnalysis.cpp | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/lib/gc/Analysis/VectorBasedFusionAnalysis.cpp b/lib/gc/Analysis/VectorBasedFusionAnalysis.cpp index 5bd67b367..230978d16 100644 --- a/lib/gc/Analysis/VectorBasedFusionAnalysis.cpp +++ b/lib/gc/Analysis/VectorBasedFusionAnalysis.cpp @@ -44,21 +44,20 @@ static inline bool isReadOrWriteOperation(Operation *op) { } /// which 
axis do the shape cast in source shape a -void shapeCastSourceAxis(const ArrayRef &a, const ArrayRef &b, +void shapeCastSourceAxis(ArrayRef a, ArrayRef b, SmallVector &res) { unsigned rankA = a.size(); unsigned rankB = b.size(); if (rankA >= rankB) - llvm::llvm_unreachable_internal("May be invalid shape cast operation."); + llvm_unreachable("May be invalid shape cast operation."); auto isOne = [](int64_t v) { return v == 1; }; // Special-case for n-D to 0-d shape cast. 'b' must be all ones to be shape // casted to a 0-d vector. if (rankA == 0 && all_of(b, isOne)) { - for (size_t i = 0; i < a.size(); i++) { + for (size_t i = 0; i < a.size(); i++) res.emplace_back(i); - } return; } @@ -71,7 +70,7 @@ void shapeCastSourceAxis(const ArrayRef &a, const ArrayRef &b, while (dimB < dimA && j < rankB) dimB *= b[j++]; if (dimA != dimB) { - llvm::llvm_unreachable_internal(" Invalid shape cast operation."); + llvm_unreachable(" Invalid shape cast operation."); break; } if (bAxisBegin != j) { @@ -134,11 +133,11 @@ void getOperationDataAxis(Operation *op, SmallVector &dataAxis) { auto dstType = shapeCastOp.getResultVectorType(); auto srcShape = srcType.getShape(); auto dstShape = dstType.getShape(); - if (srcShape.size() < dstShape.size()) { + if (srcShape.size() < dstShape.size()) shapeCastSourceAxis(srcShape, dstShape, dataAxis); - } else { + else shapeCastSourceAxis(dstShape, srcShape, dataAxis); - } + return; }) .Case([&](vector::BroadcastOp broadcastOp) { From 919dd11cb001a67bf3d1460605df8031dd386a07 Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Mon, 23 Sep 2024 14:29:33 +0800 Subject: [PATCH 62/66] update reduction rectify indice code --- include/gc/Transforms/Utils/VectorUtils.h | 31 ++++++++++++ lib/gc/Transforms/CPUPhysicalRegisterPass.cpp | 50 ++++++++++--------- lib/gc/Transforms/TilingVector.hpp | 3 +- .../gc/Transforms/cpu-phyaical-register.mlir | 3 +- 4 files changed, 61 insertions(+), 26 deletions(-) diff --git a/include/gc/Transforms/Utils/VectorUtils.h b/include/gc/Transforms/Utils/VectorUtils.h index 6b2a2bd0f..c5a6a7cba 100644 --- a/include/gc/Transforms/Utils/VectorUtils.h +++ b/include/gc/Transforms/Utils/VectorUtils.h @@ -18,6 +18,7 @@ #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" #include +#include #include #include @@ -151,6 +152,36 @@ T getInitValForReduce(vector::CombiningKind kind, Type t) { return result; } +template +void getSameBlockTargetOp(Operation *op, + std::queue &candidateOps) { + if (isa(op)) { + candidateOps.push(op); + return; + } + auto getSameBlockSrcOp = [](Operation *trackSrcOp, + std::queue &trackOps, + std::queue &candidateOps) { + for (Value opd : trackSrcOp->getOperands()) { + if (isa(opd) or + opd.getDefiningOp()->getBlock() != trackSrcOp->getBlock()) + continue; + if (isa(opd.getDefiningOp())) + candidateOps.push(opd.getDefiningOp()); + else + trackOps.push(opd.getDefiningOp()); + } + }; + + std::queue trackOps; + getSameBlockSrcOp(op, trackOps, candidateOps); + while (not trackOps.empty()) { + Operation *cadidateOp = trackOps.front(); + trackOps.pop(); + getSameBlockSrcOp(cadidateOp, trackOps, candidateOps); + } +} + } // namespace gc } // namespace mlir diff --git a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp index d8a5512ab..351cabd68 100644 --- a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp +++ b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp @@ -735,33 +735,37 @@ void updateLoopArgsData(Value val, Value originalVal, } void LoopGeneratorImpl::rectifyParallelIndice( - 
GenerateLoopHelper &loopHelperParam, OpBuilder &b, Location loc) { + GenerateLoopHelper &loopHelperParam, Location loc) { MultiReductionCanonicalizer rdCanonicalizer = getMultiRdCanonicalizers()[loopHelperParam.groupIdx]; auto &multireductionOp = rdCanonicalizer.getCandidateOps()[0]; SmallVector &reductionAxis = rdCanonicalizer.getReductionAxis(); // rectify indice of read from source operand - auto sourceReadOp = - multireductionOp.getSource().getDefiningOp(); - if (!sourceReadOp) - return; - - AffineExpr outterParallel, innerParallel; - bindDims(multireductionOp->getContext(), outterParallel, innerParallel); - - Value op = - loopHelperParam.inductionVars[loopHelperParam.inductionVars.size() - - reductionAxis.size() - 2]; - Value ip = - loopHelperParam.inductionVars[loopHelperParam.inductionVars.size() - - reductionAxis.size() - 1]; - Value newIndice = b.createOrFold( - loc, (outterParallel + innerParallel), ValueRange{op, ip}); - int parallelSize = rdCanonicalizer.getParallelAxis().size(); - int readIndiceOffset = - 1 + rdCanonicalizer.getParallelAxis()[parallelSize - 1]; - sourceReadOp->setOperand(readIndiceOffset, newIndice); + std::queue candidateOps; + getSameBlockTargetOp( + multireductionOp.getSource().getDefiningOp(), candidateOps); + while (not candidateOps.empty()) { + auto sourceReadOp = candidateOps.front(); + candidateOps.pop(); + IRRewriter rewriter(sourceReadOp); + rewriter.setInsertionPoint(sourceReadOp); + AffineExpr outterParallel, innerParallel; + bindDims(multireductionOp->getContext(), outterParallel, innerParallel); + + Value op = + loopHelperParam.inductionVars[loopHelperParam.inductionVars.size() - + reductionAxis.size() - 2]; + Value ip = + loopHelperParam.inductionVars[loopHelperParam.inductionVars.size() - + reductionAxis.size() - 1]; + Value newIndice = rewriter.createOrFold( + loc, (outterParallel + innerParallel), ValueRange{op, ip}); + int parallelSize = rdCanonicalizer.getParallelAxis().size(); + int readIndiceOffset = + 1 + rdCanonicalizer.getParallelAxis()[parallelSize - 1]; + sourceReadOp->setOperand(readIndiceOffset, newIndice); + } } scf::ForOp LoopGeneratorImpl::reductionAxisGenerateForLoop( @@ -901,7 +905,7 @@ scf::ForOp LoopGeneratorImpl::reductionAxisGenerateForLoop( loopHelperParam.loopIterArgs = loopState; moveOperationsToCurrentForBody(b, movingOperation, loopHelperParam); if (isLastDimReduction) - rectifyParallelIndice(loopHelperParam, b, loc); + rectifyParallelIndice(loopHelperParam, loc); loopHelperParam.movedOps = &movingOperation; loopHelperParam.candidateOps = &opQueue; @@ -2768,7 +2772,7 @@ void GroupOperationFusionImpl::broadcastFromElements(Operation *op, DenseElementsAttr::get(dataType, constantOp.getValue()), newOperandType); if (failed(res)) - llvm::llvm_unreachable_internal("Wrong to create constant op."); + llvm_unreachable("Wrong to create constant op."); removeOpInCurrentGroups(grpIdx, op, res.value().getDefiningOp()); } else { diff --git a/lib/gc/Transforms/TilingVector.hpp b/lib/gc/Transforms/TilingVector.hpp index 45303a29d..653297941 100644 --- a/lib/gc/Transforms/TilingVector.hpp +++ b/lib/gc/Transforms/TilingVector.hpp @@ -490,8 +490,7 @@ class LoopGeneratorImpl : public ForLoopGenerator { scf::ForOp reductionAxisGenerateForLoop(OpBuilder &opBuilder, const size_t reductionIdx, GenerateLoopHelper &loopHelperParam); - void rectifyParallelIndice(GenerateLoopHelper &loopHelperParam, OpBuilder &b, - Location loc); + void rectifyParallelIndice(GenerateLoopHelper &loopHelperParam, Location loc); /// reduction operation parallel 
axis for loop scf::ForOp parallelAxisGenerateForLoop(OpBuilder &opBuilder, GenerateLoopHelper &loopHelperParam); diff --git a/test/mlir/test/gc/Transforms/cpu-phyaical-register.mlir b/test/mlir/test/gc/Transforms/cpu-phyaical-register.mlir index 80eb35112..cf1b69656 100644 --- a/test/mlir/test/gc/Transforms/cpu-phyaical-register.mlir +++ b/test/mlir/test/gc/Transforms/cpu-phyaical-register.mlir @@ -577,7 +577,8 @@ func.func @reduce_fusePostOp_test11(%input: tensor<16x32x64xf32>, // CHECK: %[[READ0:.*]] = vector.transfer_read %[[arg5]][%[[arg2]], %[[arg4]]], %[[CST_0]] {in_bounds = [true]} : tensor<16x32xf32>, vector<16xf32> // CHECK: scf.for %[[arg6:.*]] = %[[C0]] to %[[C16]] step %[[C1]] iter_args(%[[arg7:.*]] = %[[READ0]]) -> (vector<16xf32>) // CHECK: scf.for %[[arg8:.*]] = %[[C0]] to %[[C64]] step %[[C16]] iter_args(%[[arg9:.*]] = %[[CST]]) -> (vector<16xf32>) -// CHECK: %[[READ1:.*]] = vector.transfer_read %arg0[%[[arg2]], %[[arg4]], %[[arg8]]], {{.*}} {in_bounds = [true]} : tensor<16x32x64xf32>, vector<16xf32> +// CHECK: %[[APPLY0:.*]] = affine.apply #[[map9]](%[[arg4]], %[[arg6]]) +// CHECK: %[[READ1:.*]] = vector.transfer_read %arg0[%[[arg2]], %[[APPLY0]], %[[arg8]]], {{.*}} {in_bounds = [true]} : tensor<16x32x64xf32>, vector<16xf32> // CHECK: %[[ADD0:.*]] = arith.addf %[[READ1]], %[[READ1]] : vector<16xf32> // CHECK: %[[ADD1:.*]] = arith.addf %[[ADD0]], %[[arg9]] : vector<16xf32> // CHECK: %[[REDUCTION:.*]] = vector.reduction , {{.*}} : vector<16xf32> into f32 From 0e7794c4c4f9d7df6329124a0cb9e3fe5ed976c8 Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Tue, 24 Sep 2024 09:11:53 +0800 Subject: [PATCH 63/66] use CRTP and type trait to avoid virtual function to improve compile performance --- .../gc/Analysis/VectorBasedFusionAnalysis.h | 20 +++++---- .../gc/Transforms/TilingVector.h | 37 +++++++++++++--- include/gc/Transforms/Utils/VectorUtils.h | 5 +-- lib/gc/Analysis/VectorBasedFusionAnalysis.cpp | 13 +----- lib/gc/Transforms/CPUPhysicalRegisterPass.cpp | 29 ++++++------ lib/gc/Transforms/Utils/VectorUtils.cpp | 5 +-- ...gister.mlir => cpu-physical-register.mlir} | 44 +++++++++++++++++++ 7 files changed, 106 insertions(+), 47 deletions(-) rename lib/gc/Transforms/TilingVector.hpp => include/gc/Transforms/TilingVector.h (96%) rename test/mlir/test/gc/Transforms/{cpu-phyaical-register.mlir => cpu-physical-register.mlir} (94%) diff --git a/include/gc/Analysis/VectorBasedFusionAnalysis.h b/include/gc/Analysis/VectorBasedFusionAnalysis.h index 4034effbd..a27d613e2 100644 --- a/include/gc/Analysis/VectorBasedFusionAnalysis.h +++ b/include/gc/Analysis/VectorBasedFusionAnalysis.h @@ -10,6 +10,7 @@ #define MLIR_ANALYSIS_VECTORBASEDFUSIONANALYSIS_H #include "gc/Dialect/Linalgx/LinalgxOps.h" +#include "gc/Dialect/Linalgx/Utils.h" #include "gc/Dialect/Microkernel/MicrokernelOps.h" #include "gc/Transforms/Passes.h" #include "gc/Transforms/Utils/VectorUtils.h" @@ -28,8 +29,7 @@ namespace gc { /// record hardware information struct HardWareInfo { - bool favx512f = true; - bool favx2 = true; + size_t vectorWidth = 0; }; /// Vector type conversion helper class @@ -66,6 +66,7 @@ enum class ReturnTypeKind { RT_InGroup, }; +/// Base class of vector-based fusion. 
class VectorFusionBase { private: @@ -257,16 +258,19 @@ Operation *GroupOperationFusion::getNextTargetOperationInCurrentGroup( while (!tmpOpQueue.empty()) { auto frontOp = tmpOpQueue.front(); - if (isa(frontOp)) { - for (auto x : frontOp->getOperands()) - if (x.getDefiningOp() == curOp) - return frontOp; - } tmpOpQueue.pop(); + if (not isa(frontOp)) + continue; + for (auto x : frontOp->getOperands()) + if (x.getDefiningOp() == curOp) + return frontOp; } return nullptr; } +/// Analysis each operation group class. +/// Currently it will run vector-base fusion, analysis empty group and each +/// operation group's max vectorized step. class GroupOperationAnalysis { private: /// vector-based fusion related data @@ -282,7 +286,7 @@ class GroupOperationAnalysis { void analysisGroupMaxSteps(); /// get fusion strategy GroupOperationFusion &getGroupOperationFusion() { return fusionStrategy; } - + /// running the vector-based fusion void run() { fusionStrategy.run(); } }; } // namespace gc diff --git a/lib/gc/Transforms/TilingVector.hpp b/include/gc/Transforms/TilingVector.h similarity index 96% rename from lib/gc/Transforms/TilingVector.hpp rename to include/gc/Transforms/TilingVector.h index 653297941..d238cf1a4 100644 --- a/lib/gc/Transforms/TilingVector.hpp +++ b/include/gc/Transforms/TilingVector.h @@ -18,6 +18,7 @@ #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/Transforms/Transforms.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" #include "mlir/Dialect/Vector/Transforms/Passes.h" #include "mlir/IR/Visitors.h" @@ -120,9 +121,30 @@ struct GenerateLoopHelper { //===----------------------------------------------------------------------===// // vectorize operation class //===----------------------------------------------------------------------===// +class MultiReductionCanonicalizer; +class BroadcastCanonicalizer; +class TransposeCanonicalizer; +class ShapeCastCanonicalizer; + +// fixed extraction trait +template struct SpecialOpTraits; +template <> struct SpecialOpTraits { + using DerivedSpecialT = MultiReductionCanonicalizer; +}; +template <> struct SpecialOpTraits { + using DerivedSpecialT = BroadcastCanonicalizer; +}; +template <> struct SpecialOpTraits { + using DerivedSpecialT = TransposeCanonicalizer; +}; +template <> struct SpecialOpTraits { + using DerivedSpecialT = ShapeCastCanonicalizer; +}; /// base class of special operation template class SpecialOperationCanonicalizer { + using DerivedT = typename SpecialOpTraits::DerivedSpecialT; + private: /// store current special operation SmallVector candidateRdOps; @@ -148,9 +170,12 @@ template class SpecialOperationCanonicalizer { SpecialOperationCanonicalizer(const SmallVector &candidateRdOps, SpecialOperationKind kind, size_t step) : candidateRdOps(candidateRdOps), vectorStep(step), kind(kind) {} - llvm::SmallVector &getCandidateOps(); + SmallVector &getCandidateOps(); virtual ~SpecialOperationCanonicalizer() {} - virtual void prepareSpecialOperationInfo() = 0; + /// call derived speical operation init information methods + void prepareSpecialOperationInfo() { + static_cast(this)->prepareSpecialInfo(); + } /// get kind of speical operation SpecialOperationKind getKind() noexcept { return kind; } /// set current operation group vectorize step @@ -241,7 +266,7 @@ class MultiReductionCanonicalizer /// initalize parallel, reduction axis, reduction operation type and whether /// last dim is reduction axis - void 
prepareSpecialOperationInfo() override; + void prepareSpecialInfo(); static bool classof(SpecialOperationCanonicalizer *canonicalizer) { return canonicalizer->getKind() == @@ -259,7 +284,7 @@ class BroadcastCanonicalizer : SpecialOperationCanonicalizer( candidateBcOps, SpecialOperationKind::OP_Broadcast, steps){}; virtual ~BroadcastCanonicalizer() noexcept {} - void prepareSpecialOperationInfo() override {} + void prepareSpecialInfo(){}; static bool classof(SpecialOperationCanonicalizer *canonicalizer) { return canonicalizer->getKind() == SpecialOperationKind::OP_Broadcast; } @@ -278,7 +303,7 @@ class TransposeCanonicalizer : SpecialOperationCanonicalizer( candidateTpOps, SpecialOperationKind::OP_Transpose, steps){}; virtual ~TransposeCanonicalizer() noexcept {} - void prepareSpecialOperationInfo() override{}; + void prepareSpecialInfo(){}; static bool classof(SpecialOperationCanonicalizer *canonicalizer) { return canonicalizer->getKind() == SpecialOperationKind::OP_Transpose; } @@ -306,7 +331,7 @@ class ShapeCastCanonicalizer : SpecialOperationCanonicalizer( candidateScOps, SpecialOperationKind::OP_ShapeCast, steps){}; virtual ~ShapeCastCanonicalizer() {} - void prepareSpecialOperationInfo() override {} + void prepareSpecialInfo() {} static bool classof(SpecialOperationCanonicalizer *canonicalizer) { return canonicalizer->getKind() == SpecialOperationKind::OP_ShapeCast; } diff --git a/include/gc/Transforms/Utils/VectorUtils.h b/include/gc/Transforms/Utils/VectorUtils.h index c5a6a7cba..a7e1ca524 100644 --- a/include/gc/Transforms/Utils/VectorUtils.h +++ b/include/gc/Transforms/Utils/VectorUtils.h @@ -54,9 +54,8 @@ namespace gc { /// block. /// insert_slice just move them to the privious of the first operation which /// use it. -void moveSomeInterferenceOperation( - func::FuncOp *func, MLIRContext *ctx, - std::function &conditionalFunc); +void moveOpsFrontOrBack(func::FuncOp *func, MLIRContext *ctx, + std::function &conditionalFunc); /// build a constant operation of index type Value makeIndexArithConstantOp(OpBuilder &opBuilder, const Location &loc, diff --git a/lib/gc/Analysis/VectorBasedFusionAnalysis.cpp b/lib/gc/Analysis/VectorBasedFusionAnalysis.cpp index 230978d16..23ff5faa5 100644 --- a/lib/gc/Analysis/VectorBasedFusionAnalysis.cpp +++ b/lib/gc/Analysis/VectorBasedFusionAnalysis.cpp @@ -6,7 +6,6 @@ // //===----------------------------------------------------------------------===// #include "gc/Analysis/VectorBasedFusionAnalysis.h" -#include "gc/Dialect/Linalgx/Utils.h" namespace mlir { namespace gc { @@ -397,17 +396,7 @@ int TypeHelper::generateValidSteps(int steps, VectorType type) { // Get the maximum number of current data types that a register can hold [[nodiscard]] int TypeHelper::getDataTypeMAXSIMDLength(VectorType type) { auto typebits = type.getElementTypeBitWidth(); - const int favx512bits = 512; - const int favx2bits = 256; - if (info.favx512f) - return favx512bits / typebits; - - if (info.favx2) - return favx2bits / typebits; - - // invalid hardware - llvm_unreachable("Invalid hardware."); - return -1; + return info.vectorWidth / typebits; } /// Get a appropriate for loop step for current vector type diff --git a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp index 351cabd68..f6223e70b 100644 --- a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp +++ b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp @@ -5,7 +5,7 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // 
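// A minimal, self-contained sketch of the trait-based static dispatch that
// SpecialOperationCanonicalizer adopts above: the base class resolves its
// derived type through a SpecialOpTraits specialization and calls the derived
// hook directly, so no virtual table is involved. All names here are
// hypothetical stand-ins used only for illustration.
#include <iostream>

struct ReductionCanonicalizerSketch; // forward declaration
struct ReductionOpSketch {};         // stand-in for the handled vector op

template <typename OpT> struct SpecialOpTraitsSketch;
template <> struct SpecialOpTraitsSketch<ReductionOpSketch> {
  using DerivedSpecialT = ReductionCanonicalizerSketch;
};

template <typename OpT> struct CanonicalizerBaseSketch {
  void prepareSpecialOperationInfo() {
    using Derived = typename SpecialOpTraitsSketch<OpT>::DerivedSpecialT;
    static_cast<Derived *>(this)->prepareSpecialInfo(); // static dispatch
  }
};

struct ReductionCanonicalizerSketch
    : CanonicalizerBaseSketch<ReductionOpSketch> {
  void prepareSpecialInfo() { std::cout << "init reduction info\n"; }
};

int main() {
  ReductionCanonicalizerSketch c;
  c.prepareSpecialOperationInfo(); // resolved at compile time, no vtable
}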
//===----------------------------------------------------------------------===// -#include "TilingVector.hpp" +#include "gc/Transforms/TilingVector.h" namespace mlir { namespace gc { @@ -1802,7 +1802,7 @@ bool MultiReductionCanonicalizer::hasLastDimReduction() { return res; } -void MultiReductionCanonicalizer::prepareSpecialOperationInfo() { +void MultiReductionCanonicalizer::prepareSpecialInfo() { if (getCandidateOps().empty()) return; @@ -2110,9 +2110,8 @@ void ForLoopGenerator::setOperationCorrectOperand( loopHelperParam .loopIterArgs[loopHelperParam.currentLoopStateIdxMap.at(loopArg)]); } - int offset = isa(op) ? 2 : 1; - if (dyn_cast(op) || - dyn_cast(op)) { + int operandOffset = isa(op) ? 2 : 1; + if (isReadOrWriteOperation(op)) { if (not opPermuationMap.contains(op)) llvm_unreachable("Map must contains operation."); @@ -2133,7 +2132,7 @@ void ForLoopGenerator::setOperationCorrectOperand( } ShapedType tensorType = - cast(op->getOperandTypes()[offset - 1]); + cast(op->getOperandTypes()[operandOffset - 1]); int64_t varIdx = dim; if (tensorType.getRank() > (int64_t)loopHelperParam.inductionVars.size()) { @@ -2146,11 +2145,12 @@ void ForLoopGenerator::setOperationCorrectOperand( } if (loopHelperParam.indiceLoopMap.contains(op)) op->setOperand( - dim + offset, + dim + operandOffset, loopHelperParam .inductionVars[loopHelperParam.indiceLoopMap[op][varIdx]]); else - op->setOperand(dim + offset, loopHelperParam.inductionVars[varIdx]); + op->setOperand(dim + operandOffset, + loopHelperParam.inductionVars[varIdx]); } if (auto readOp = dyn_cast(op)) { size_t grpIdx = getVectorBasedFusion().getOpGroupIndexMap()[op]; @@ -2780,8 +2780,8 @@ void GroupOperationFusionImpl::broadcastFromElements(Operation *op, op->getLoc(), newOperandType, op->getOperands()[0]); removeOpInCurrentGroups(grpIdx, op, bcastOp); std::function candidateFunc = isBroadcastOp; - moveSomeInterferenceOperation(&getGroupOperationFusion().getFunction(), - op->getContext(), candidateFunc); + moveOpsFrontOrBack(&getGroupOperationFusion().getFunction(), + op->getContext(), candidateFunc); } } } @@ -2946,22 +2946,21 @@ struct CPUPhysicalRegisterPass } // affineApply operation is always used by other operations. std::function candidateFunc = isProducerOp; - moveSomeInterferenceOperation(&func, ctx, candidateFunc); + moveOpsFrontOrBack(&func, ctx, candidateFunc); candidateFunc = isCandidateMoveOperations; - moveSomeInterferenceOperation(&func, ctx, candidateFunc); + moveOpsFrontOrBack(&func, ctx, candidateFunc); // canonicalize vector operation, default use vector-based fusion // strategy. HardWareInfo hwInfo; CPUTargetDescriptionAnalysis sysDesc = getAnalysis(); - hwInfo.favx512f = sysDesc.getMaxVectorWidth() >= 512; - hwInfo.favx2 = sysDesc.getMaxVectorWidth() >= 256; + hwInfo.vectorWidth = sysDesc.getMaxVectorWidth(); VectorOperationCanonicalizer canonicalizer( func, hwInfo, CanonicalizerKind::GroupOperations); canonicalizer.run(); candidateFunc = isReadOrWriteOperation; - moveSomeInterferenceOperation(&func, ctx, candidateFunc); + moveOpsFrontOrBack(&func, ctx, candidateFunc); // transpose kernel vector::VectorTransformsOptions transposeOptions = diff --git a/lib/gc/Transforms/Utils/VectorUtils.cpp b/lib/gc/Transforms/Utils/VectorUtils.cpp index e76f78765..c0897eb42 100644 --- a/lib/gc/Transforms/Utils/VectorUtils.cpp +++ b/lib/gc/Transforms/Utils/VectorUtils.cpp @@ -154,9 +154,8 @@ void moveCandidateOperation( // block. // insert_slice just move them to the privious of the first operation which // use it. 
-void moveSomeInterferenceOperation( - func::FuncOp *func, MLIRContext *ctx, - std::function &conditionalFunc) { +void moveOpsFrontOrBack(func::FuncOp *func, MLIRContext *ctx, + std::function &conditionalFunc) { // Pre-order traversal of each op // Record each operation position. Inorder to we can kown current operation // should move after which operation. diff --git a/test/mlir/test/gc/Transforms/cpu-phyaical-register.mlir b/test/mlir/test/gc/Transforms/cpu-physical-register.mlir similarity index 94% rename from test/mlir/test/gc/Transforms/cpu-phyaical-register.mlir rename to test/mlir/test/gc/Transforms/cpu-physical-register.mlir index cf1b69656..ae1b18f4b 100644 --- a/test/mlir/test/gc/Transforms/cpu-phyaical-register.mlir +++ b/test/mlir/test/gc/Transforms/cpu-physical-register.mlir @@ -664,3 +664,47 @@ func.func @add_small_tensor_test14(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) - %2 = linalg.max ins(%1, %cst : tensor<2xf32>, tensor<2xf32>) outs(%0: tensor<2xf32>) -> tensor<2xf32> return %2 : tensor<2xf32> } + +// CHECK-LABEL: func @broadcast_add_test15 +// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[C64:.*]] = arith.constant 64 : index +// CHECK: %[[C16:.*]] = arith.constant 16 : index +// CHECK: scf.for %[[arg2:.*]] = %[[C0]] to %[[C64]] step %[[C1]] iter_args(%[[arg3:.*]] = {{.*}}) -> (tensor<64x64xf32>) +// CHECK: scf.for %[[arg4:.*]] = %[[C0]] to %[[C64]] step %[[C16]] iter_args(%[[arg5:.*]] = %[[arg3]]) -> (tensor<64x64xf32>) +// CHECK: %[[READ0:.*]] = vector.transfer_read %[[arg5]][%[[arg2]], %[[arg4]]], %[[CST]] {in_bounds = [true]} : tensor<64x64xf32>, vector<16xf32> +// CHECK: %[[READ1:.*]] = vector.transfer_read %arg0[%[[arg4]]], %[[CST]] {in_bounds = [true]} : tensor<64xf32>, vector<16xf32> +// CHECK: %[[ADD0:.*]] = arith.addf %[[READ1]], %[[READ0]] : vector<16xf32> +// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[ADD0]], %[[arg5]][%[[arg2]], %[[arg4]]] {in_bounds = [true]} : vector<16xf32>, tensor<64x64xf32> +func.func @broadcast_add_test15(%arg0: tensor<64xf32>, %arg1: tensor<64x64xf32>) -> tensor<64x64xf32> { + %0 = tensor.empty() : tensor<64x64xf32> + %bcast = linalg.broadcast + ins(%arg0:tensor<64xf32>) + outs(%0:tensor<64x64xf32>) + dimensions = [0] + %out3 = linalg.add ins(%bcast, %arg1: tensor<64x64xf32>, tensor<64x64xf32>) + outs(%arg1: tensor<64x64xf32>) -> tensor<64x64xf32> + return %out3: tensor<64x64xf32> +} + +// CHECK-LABEL: func @broadcast_single_test16 +// CHECK: %[[C16:.*]] = arith.constant 16 : index +// CHECK: %[[C64:.*]] = arith.constant 64 : index +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[EMPTY0:.*]] = tensor.empty() : tensor<64x64xf32> +// CHECK: scf.for %[[arg1:.*]] = %[[C0]] to %[[C64]] step %[[C1]] iter_args(%[[arg2:.*]] = %[[EMPTY0]]) -> (tensor<64x64xf32>) +// CHECK: scf.for %[[arg3:.*]] = %[[C0]] to %[[C64]] step %[[C16]] iter_args(%[[arg4:.*]] = %[[arg2]]) -> (tensor<64x64xf32>) +// CHECK: %[[READ0:.*]] = vector.transfer_read %arg0[%[[arg3]]], %[[CST]] {in_bounds = [true]} : tensor<64xf32>, vector<16xf32> +// CHECK: %[[WRITE0:.*]] = vector.transfer_write %[[READ0]], %[[arg4]][%[[arg1]], %[[arg3]]] {in_bounds = [true]} : vector<16xf32>, tensor<64x64xf32> +func.func @broadcast_single_test16(%arg0: tensor<64xf32>) -> tensor<64x64xf32> { + %0 = tensor.empty() : tensor<64x64xf32> + %bcast = 
linalg.broadcast + ins(%arg0: tensor<64xf32>) + outs(%0:tensor<64x64xf32>) + dimensions = [0] + return %bcast: tensor<64x64xf32> +} + From 705d249ab0472fa59725bd259ca4673597a0037f Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Wed, 25 Sep 2024 16:53:33 +0800 Subject: [PATCH 64/66] fix comments --- .../gc/Analysis/VectorBasedFusionAnalysis.h | 45 ++-- include/gc/Transforms/TilingVector.h | 27 ++- include/gc/Transforms/Utils/VectorUtils.h | 12 +- lib/gc/Transforms/CPUPhysicalRegisterPass.cpp | 217 +++++++++--------- lib/gc/Transforms/Utils/VectorUtils.cpp | 208 ++++++++++------- .../gc/Transforms/cpu-physical-register.mlir | 30 +-- 6 files changed, 304 insertions(+), 235 deletions(-) diff --git a/include/gc/Analysis/VectorBasedFusionAnalysis.h b/include/gc/Analysis/VectorBasedFusionAnalysis.h index a27d613e2..1925d4fcb 100644 --- a/include/gc/Analysis/VectorBasedFusionAnalysis.h +++ b/include/gc/Analysis/VectorBasedFusionAnalysis.h @@ -74,19 +74,25 @@ class VectorFusionBase { func::FuncOp func; /// Type helper class, can help us to get operation type TypeHelper typehelper; + /// IR rewriter + IRRewriter *rewriter; public: - VectorFusionBase() = default; - VectorFusionBase(func::FuncOp &func, HardWareInfo &info) - : func(func), typehelper(info) {} - VectorFusionBase(VectorFusionBase &base) - : func(base.getFunction()), typehelper(base.getHardwareInfo()) {} + VectorFusionBase(func::FuncOp &func, HardWareInfo &info, IRRewriter *rewriter) + : func(func), typehelper(info), rewriter(rewriter) {} + VectorFusionBase(VectorFusionBase &base, IRRewriter *rewriter) + : func(base.getFunction()), typehelper(base.getHardwareInfo()), + rewriter(rewriter) {} /// get current function IR func::FuncOp &getFunction() { return func; } /// get current hardware info - HardWareInfo &getHardwareInfo() { return typehelper.getHardwareInfo(); } - TypeHelper &getTypeHelper() { return typehelper; } + HardWareInfo &getHardwareInfo() noexcept { + return typehelper.getHardwareInfo(); + } + TypeHelper &getTypeHelper() noexcept { return typehelper; } + IRRewriter *getRewriter() noexcept { return rewriter; } + void setRewriter(IRRewriter *rewriter) noexcept { this->rewriter = rewriter; } }; /// Group operation fusion strategy class. 
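// The function-level IRRewriter threaded through these classes is meant to be
// scoped at each use site rather than reconstructed per call; a minimal sketch
// of the idiom used throughout this patch (the helper name is hypothetical):
static void rewriteAtSketch(Operation *op, IRRewriter &rewriter) {
  OpBuilder::InsertionGuard g(rewriter); // restores the insertion point on exit
  rewriter.setInsertionPoint(op);
  // ... build replacement IR for `op` with `rewriter` here ...
}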
@@ -132,17 +138,20 @@ class GroupOperationFusion : public VectorFusionBase { DenseMap operandOriginalValue; public: - GroupOperationFusion(func::FuncOp &func, HardWareInfo &info) - : VectorFusionBase(func, info) {} + GroupOperationFusion(func::FuncOp &func, HardWareInfo &info, + IRRewriter *rewriter) + : VectorFusionBase(func, info, rewriter) {} - GroupOperationFusion(GroupOperationFusion &strategy) - : VectorFusionBase(strategy.getFunction(), strategy.getHardwareInfo()), + GroupOperationFusion(GroupOperationFusion &strategy, IRRewriter *rewriter) + : VectorFusionBase(strategy.getFunction(), strategy.getHardwareInfo(), + rewriter), opGroups(strategy.opGroups), groupMaxSteps(strategy.groupMaxSteps), opGroupIndexMap(strategy.opGroupIndexMap), opAnchorPos(strategy.opAnchorPos){}; - GroupOperationFusion(GroupOperationFusion &&strategy) - : VectorFusionBase(strategy.getFunction(), strategy.getHardwareInfo()), + GroupOperationFusion(GroupOperationFusion &&strategy, IRRewriter *rewriter) + : VectorFusionBase(strategy.getFunction(), strategy.getHardwareInfo(), + rewriter), opGroups(std::move(strategy.opGroups)), groupMaxSteps(std::move(strategy.groupMaxSteps)), groupBigestRankVectorType( @@ -165,9 +174,9 @@ class GroupOperationFusion : public VectorFusionBase { this->getFunction() = fusion.getFunction(); this->getHardwareInfo() = fusion.getHardwareInfo(); this->getTypeHelper() = fusion.getTypeHelper(); + this->setRewriter(fusion.getRewriter()); return *this; }; - GroupOperationFusion &operator=(GroupOperationFusion &&) = default; /// Get the map which contains each group vector type which has biggest /// rank. @@ -275,10 +284,12 @@ class GroupOperationAnalysis { private: /// vector-based fusion related data GroupOperationFusion fusionStrategy; + IRRewriter *rewriter; public: - GroupOperationAnalysis(func::FuncOp &func, HardWareInfo &info) - : fusionStrategy(func, info) {} + GroupOperationAnalysis(func::FuncOp &func, HardWareInfo &info, + IRRewriter *rewriter) + : fusionStrategy(func, info, rewriter), rewriter(rewriter) {} /// remove the useless operation, due to it result is not require by other /// operation void analysisEmptyGroup(); @@ -288,6 +299,8 @@ class GroupOperationAnalysis { GroupOperationFusion &getGroupOperationFusion() { return fusionStrategy; } /// running the vector-based fusion void run() { fusionStrategy.run(); } + /// get current function rewriter + IRRewriter *getRewriter() { return rewriter; } }; } // namespace gc } // namespace mlir diff --git a/include/gc/Transforms/TilingVector.h b/include/gc/Transforms/TilingVector.h index d238cf1a4..90cca7101 100644 --- a/include/gc/Transforms/TilingVector.h +++ b/include/gc/Transforms/TilingVector.h @@ -1,4 +1,4 @@ -//===- TilingVector.hpp - Tiling large vector to small vector ---*- C++ -*-===// +//===- TilingVector.h - Tiling large vector to small vector -----*- C++ -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
@@ -343,12 +343,16 @@ class ShapeCastCanonicalizer class ForLoopGenerator { private: GroupOperationFusion vectorBasedFusion; + IRRewriter *rewriter; public: - ForLoopGenerator(GroupOperationFusion &fusion) : vectorBasedFusion(fusion) {} + ForLoopGenerator(GroupOperationFusion &fusion, IRRewriter *rewriter) + : vectorBasedFusion(fusion, rewriter), rewriter(rewriter) {} virtual ~ForLoopGenerator() noexcept {} + IRRewriter *getRewriter() noexcept { return rewriter; } + void setVectorBaseFusion(GroupOperationFusion &vectorBasedFusion) { this->vectorBasedFusion = vectorBasedFusion; }; @@ -466,7 +470,8 @@ class LoopGeneratorImpl : public ForLoopGenerator { SmallVector shapeCastCanonicalizers; public: - LoopGeneratorImpl(GroupOperationFusion &fusion) : ForLoopGenerator(fusion){}; + LoopGeneratorImpl(GroupOperationFusion &fusion, IRRewriter *rewriter) + : ForLoopGenerator(fusion, rewriter){}; virtual ~LoopGeneratorImpl() noexcept {}; @@ -569,8 +574,9 @@ class GroupOperationFusionImpl : public GroupOperationAnalysis { public: virtual ~GroupOperationFusionImpl() = default; - GroupOperationFusionImpl(func::FuncOp &func, HardWareInfo &info) - : GroupOperationAnalysis(func, info) {} + GroupOperationFusionImpl(func::FuncOp &func, HardWareInfo &info, + IRRewriter *rewriter) + : GroupOperationAnalysis(func, info, rewriter) {} void broadcastFromElements(Operation *op, size_t grpIdx); void scalarOperandFromElements(); @@ -632,17 +638,20 @@ class VectorOperationCanonicalizer { LoopGeneratorImpl loopGenerator; CanonicalizerKind kind; func::FuncOp func; - IRRewriter rewriter; + IRRewriter *rewriter; public: VectorOperationCanonicalizer( - func::FuncOp &func, HardWareInfo &info, + func::FuncOp &func, HardWareInfo &info, IRRewriter *rewriter, CanonicalizerKind kind = CanonicalizerKind::GroupOperations) - : fusion(func, info), loopGenerator(fusion.getGroupOperationFusion()), - kind(kind), rewriter(func) {} + : fusion(func, info, rewriter), + loopGenerator(fusion.getGroupOperationFusion(), rewriter), kind(kind), + rewriter(rewriter) {} virtual ~VectorOperationCanonicalizer() = default; /// run the vector canonicalizer for the IR void run(); + /// get current funtion rewriter + IRRewriter *getRewriter() noexcept { return rewriter; } }; } // namespace gc } // namespace mlir diff --git a/include/gc/Transforms/Utils/VectorUtils.h b/include/gc/Transforms/Utils/VectorUtils.h index a7e1ca524..e70d96155 100644 --- a/include/gc/Transforms/Utils/VectorUtils.h +++ b/include/gc/Transforms/Utils/VectorUtils.h @@ -24,6 +24,14 @@ namespace mlir { namespace gc { + +enum class OPPRIORITY : uint8_t { + FIRST = 0, + SECOND, + THIRD, + LAST, + OTHERS = 255, +}; /// Need to move some operations like extract_slice or insert_slice. /// Because those operation may interpret our analysis result. e.g.: /// ``` @@ -54,8 +62,8 @@ namespace gc { /// block. /// insert_slice just move them to the privious of the first operation which /// use it. 
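/// For reference, the priorities assigned by getOperationPriority later in
/// this patch are:
///   FIRST  : affine.apply
///   SECOND : tensor.extract_slice, tensor.empty, tensor.insert_slice
///   THIRD  : vector.broadcast
///   LAST   : vector.transfer_read, vector.transfer_write
/// This lets the pass driver hoist producers before fusion and defer the
/// read/write movement until after it, e.g.
///   moveOpsFrontOrBack(&func, rewriter, OPPRIORITY::FIRST, OPPRIORITY::SECOND);
///   moveOpsFrontOrBack(&func, rewriter, OPPRIORITY::LAST, OPPRIORITY::LAST);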
-void moveOpsFrontOrBack(func::FuncOp *func, MLIRContext *ctx, - std::function &conditionalFunc); +void moveOpsFrontOrBack(func::FuncOp *func, IRRewriter &rewriter, + OPPRIORITY start, OPPRIORITY end); /// build a constant operation of index type Value makeIndexArithConstantOp(OpBuilder &opBuilder, const Location &loc, diff --git a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp index f6223e70b..4059e21a3 100644 --- a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp +++ b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp @@ -271,13 +271,11 @@ scf::YieldOp maybeYieldValue(OpBuilder &b, Location loc, ValueRange value) { return b.create(loc); } -Operation *createTensorEmptyBefore(Operation *op) { - +Operation *createTensorEmptyBefore(Operation *op, IRRewriter &rewriter) { + OpBuilder::InsertionGuard g(rewriter); auto rtType = cast(op->getResultTypes()[0]); - IRRewriter reWriter(op); Block *block = op->getBlock(); - - reWriter.setInsertionPoint(block, block->getOperations().begin()); + rewriter.setInsertionPoint(block, block->getOperations().begin()); SmallVector shapes; SmallVector dynDims; @@ -285,16 +283,19 @@ Operation *createTensorEmptyBefore(Operation *op) { shapes.push_back(rtType.getDimSize(i)); if (rtType.isDynamicDim(i)) dynDims.push_back( - reWriter.create(op->getLoc(), op->getResult(0), i)); + rewriter.create(op->getLoc(), op->getResult(0), i)); } - auto emtpyOp = reWriter.create( + auto emtpyOp = rewriter.create( op->getLoc(), rtType.getShape(), rtType.getElementType(), dynDims); return emtpyOp; } /// get the tensor that operation should write into -Value getOperationResultTensor( - Operation *op, DenseMap &visitedOperation) { +Value getOperationResultTensor(Operation *op, + DenseMap &visitedOperation, + IRRewriter &rewriter) { + OpBuilder::InsertionGuard g(rewriter); + OpResult result = op->getResults()[0]; for (Operation *x : result.getUsers()) { if (not isa(x)) @@ -313,18 +314,21 @@ Value getOperationResultTensor( } LDBG("Result not write back to tensor."); - return createTensorEmptyBefore(op)->getResults()[0]; + return createTensorEmptyBefore(op, rewriter)->getResults()[0]; } -Operation *createTransferWriteOpAfter(Operation *op, const Value &dest) { +Operation *createTransferWriteOpAfter(Operation *op, const Value &dest, + IRRewriter &rewriter) { + OpBuilder::InsertionGuard g(rewriter); + auto rtType = cast(op->getResultTypes()[0]); int64_t rank = rtType.getRank(); auto dstType = cast(dest.getType()); - IRRewriter reWriter(op); + rewriter.setInsertionPoint(op); - auto zero = reWriter.create(op->getLoc(), 0); + auto zero = rewriter.create(op->getLoc(), 0); - reWriter.setInsertionPointAfter(op); + rewriter.setInsertionPointAfter(op); SmallVector inBoundsVal(rank, true); SmallVector shapes; @@ -333,9 +337,9 @@ Operation *createTransferWriteOpAfter(Operation *op, const Value &dest) { shapes.push_back(rtType.getDimSize(i)); if (rtType.isDynamicDim(i)) dynDims.push_back( - reWriter.create(op->getLoc(), op->getResult(0), i)); + rewriter.create(op->getLoc(), op->getResult(0), i)); } - return reWriter.create( + return rewriter.create( op->getLoc(), /*vector=*/op->getResult(0), /*source=*/dest, @@ -345,14 +349,13 @@ Operation *createTransferWriteOpAfter(Operation *op, const Value &dest) { Operation *GroupOperationFusionImpl::createTransferReadOpBefore( Operation *op, const Value &operand, vector::TransferReadOp *srcReadOp) { + IRRewriter &rewriter = *getRewriter(); + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(op); auto 
operandType = cast(operand.getType()); - - IRRewriter rewriter(op); - auto zero = - rewriter.create(rewriter.getUnknownLoc(), 0); + auto zero = rewriter.create(op->getLoc(), 0); auto padValue = rewriter.create( - rewriter.getUnknownLoc(), - rewriter.getZeroAttr(operandType.getElementType())); + op->getLoc(), rewriter.getZeroAttr(operandType.getElementType())); if (srcReadOp) { auto resultType = cast(srcReadOp->getType()); @@ -397,9 +400,10 @@ Operation *GroupOperationFusionImpl::createTransferReadOpBefore( // result into the empty tensor [[nodiscard]] std::pair canonicalizeSourceOperation(Operation *op, - DenseMap &visitedOperation) { - auto resultTensor = getOperationResultTensor(op, visitedOperation); - auto writeOp = createTransferWriteOpAfter(op, resultTensor); + DenseMap &visitedOperation, + IRRewriter &rewriter) { + auto resultTensor = getOperationResultTensor(op, visitedOperation, rewriter); + auto writeOp = createTransferWriteOpAfter(op, resultTensor, rewriter); return std::make_pair(resultTensor, writeOp->getResults()[0]); } @@ -657,6 +661,7 @@ void ForLoopGenerator::movePreOpToCurrentAnchor( void ForLoopGenerator::movePostOpToCurrentAnchor( OpBuilder &b, GenerateLoopHelper &loopHelperParam) { + OpBuilder::InsertionGuard g(b); std::queue movingOperations; // 1. get post-op to current loop bod @@ -669,8 +674,7 @@ void ForLoopGenerator::movePostOpToCurrentAnchor( moveOperationsToCurrentForBody(b, movingOperations, loopHelperParam); // 4. replace correct for loop result to post-op - IRRewriter rewriter(b); - replaceOperationsWithForLoopResult(rewriter, movingOperations, + replaceOperationsWithForLoopResult(*getRewriter(), movingOperations, loopHelperParam); // 5. move operations to moved queue @@ -683,6 +687,7 @@ void ForLoopGenerator::movePostOpToCurrentAnchor( void ForLoopGenerator::generateLoopResults( OpBuilder &b, const Location &loc, GenerateLoopHelper &loopHelperParam, DenseMap &nextOperandIdxMap) { + OpBuilder::InsertionGuard g(b); SmallVector results; DenseMap currentResultMap; getResultInCurrentOps(loopHelperParam.anchorIdx, loopHelperParam.groupIdx, @@ -736,6 +741,7 @@ void updateLoopArgsData(Value val, Value originalVal, void LoopGeneratorImpl::rectifyParallelIndice( GenerateLoopHelper &loopHelperParam, Location loc) { + OpBuilder::InsertionGuard g(*getRewriter()); MultiReductionCanonicalizer rdCanonicalizer = getMultiRdCanonicalizers()[loopHelperParam.groupIdx]; auto &multireductionOp = rdCanonicalizer.getCandidateOps()[0]; @@ -748,8 +754,7 @@ void LoopGeneratorImpl::rectifyParallelIndice( while (not candidateOps.empty()) { auto sourceReadOp = candidateOps.front(); candidateOps.pop(); - IRRewriter rewriter(sourceReadOp); - rewriter.setInsertionPoint(sourceReadOp); + getRewriter()->setInsertionPoint(sourceReadOp); AffineExpr outterParallel, innerParallel; bindDims(multireductionOp->getContext(), outterParallel, innerParallel); @@ -759,7 +764,7 @@ void LoopGeneratorImpl::rectifyParallelIndice( Value ip = loopHelperParam.inductionVars[loopHelperParam.inductionVars.size() - reductionAxis.size() - 1]; - Value newIndice = rewriter.createOrFold( + Value newIndice = getRewriter()->createOrFold( loc, (outterParallel + innerParallel), ValueRange{op, ip}); int parallelSize = rdCanonicalizer.getParallelAxis().size(); int readIndiceOffset = @@ -794,9 +799,6 @@ scf::ForOp LoopGeneratorImpl::reductionAxisGenerateForLoop( ? 
loopStep : 1; - func::FuncOp func = fusionStrategy.getFunction(); - IRRewriter rewriterOfFunc(func); - Value zero = makeIndexArithConstantOp(opBuilder, loc, 0); Value forSteps = makeIndexArithConstantOp(opBuilder, loc, loopStep); Value numIter = makeIndexArithConstantOp( @@ -994,14 +996,13 @@ void LoopGeneratorImpl::ensureAccInParallelLoop( /// This function also call reduction axis for loop scf::ForOp LoopGeneratorImpl::parallelAxisGenerateForLoop( OpBuilder &opBuilder, GenerateLoopHelper &loopHelperParam) { + OpBuilder::InsertionGuard g(opBuilder); MultiReductionCanonicalizer &rdCanonicalizer = getMultiRdCanonicalizers()[loopHelperParam.groupIdx]; vector::MultiDimReductionOp &multiReductionOp = rdCanonicalizer.getCandidateOps()[0]; VectorType vectorType = rdCanonicalizer.getSourceType(); GroupOperationFusion &fusionStrategy = getVectorBasedFusion(); - func::FuncOp func = fusionStrategy.getFunction(); - IRRewriter rewriterOfFunc(func); SmallVector ¶llelAxis = rdCanonicalizer.getParallelAxis(); const Location &loc = multiReductionOp.getLoc(); @@ -1326,7 +1327,7 @@ void ForLoopGenerator::replaceOpUsersWithForLoopResult( scf::ForOp forOp, int grpIdx, SmallVector &nextAnchorResults, DenseMap &nextAnchorResultsIdxMap, DenseMap &forResultOrignalResultMap) { - IRRewriter rewriter(forOp); + DenseSet forOpChildOps; forOp->walk([&](Operation *op) { forOpChildOps.insert(op); }); auto replaceIfFn = [&](OpOperand &use) { @@ -1337,13 +1338,13 @@ void ForLoopGenerator::replaceOpUsersWithForLoopResult( Value forResult = forOp->getResults()[nextAnchorResultsIdxMap[x]]; // subsequent group must use the replaced result as operand rectifyGroupOperands(grpIdx, originalResult, forResult); - rewriter.replaceOpUsesWithIf(originalResult.getDefiningOp(), forResult, - replaceIfFn); + getRewriter()->replaceOpUsesWithIf(originalResult.getDefiningOp(), + forResult, replaceIfFn); } } scf::ForOp LoopGeneratorImpl::generateMultiReductionForLoop(const size_t grpIdx) { - + OpBuilder::InsertionGuard g(*getRewriter()); DenseMap> indiceLoopMap; rearrageMultiReductionIR(grpIdx, indiceLoopMap); // get current loop init args @@ -1354,18 +1355,17 @@ LoopGeneratorImpl::generateMultiReductionForLoop(const size_t grpIdx) { MultiReductionCanonicalizer &rdCanonicalizer = getMultiRdCanonicalizers()[grpIdx]; - OpBuilder opBuilder(rdCanonicalizer.getCandidateOps()[0]); + getRewriter()->setInsertionPoint(rdCanonicalizer.getCandidateOps()[0]); loopHelper.indiceLoopMap = indiceLoopMap; - scf::ForOp forOp = parallelAxisGenerateForLoop(opBuilder, loopHelper); + scf::ForOp forOp = parallelAxisGenerateForLoop(*getRewriter(), loopHelper); replaceOpUsersWithForLoopResult(forOp, grpIdx, loopHelper.nextAnchorResults, loopHelper.nextAnchorResultsIdxMap, loopHelper.nextAnchorResultOrignalResultMap); vector::MultiDimReductionOp multiReductionOp = rdCanonicalizer.getCandidateOps()[0]; - IRRewriter rewriter(multiReductionOp); - rewriter.eraseOp(multiReductionOp); + getRewriter()->eraseOp(multiReductionOp); return forOp; } @@ -1625,7 +1625,7 @@ void ForLoopGenerator::rectifyReadOperationIndice( /// generate transpose for loop scf::ForOp LoopGeneratorImpl::generateShapeCastForLoop(const size_t grpIdx) { - + OpBuilder::InsertionGuard g(*getRewriter()); ShapeCastCanonicalizer &scCanonicalizer = getShapeCastCanonicalizers()[grpIdx]; vector::ShapeCastOp &scOp = scCanonicalizer.getCandidateOps()[0]; @@ -1647,7 +1647,7 @@ scf::ForOp LoopGeneratorImpl::generateShapeCastForLoop(const size_t grpIdx) { iterArgs.emplace_back(successorWriteOp->getOperands()[1]); 
SmallVector inductionVars; - IRRewriter rewriter(scOp); + getRewriter()->setInsertionPoint(scOp); const size_t groupStep = getVectorBasedFusion().getGroupMaxSteps()[grpIdx]; bool isSourceMultiple = @@ -1667,9 +1667,9 @@ scf::ForOp LoopGeneratorImpl::generateShapeCastForLoop(const size_t grpIdx) { inductionVars, iterArgs); } for (auto [idx, successorWriteOp] : enumerate(successorWriteOps)) - rewriter.replaceOp(successorWriteOp, forOp->getResults()[idx]); + getRewriter()->replaceOp(successorWriteOp, forOp->getResults()[idx]); - rewriter.eraseOp(scOp); + getRewriter()->eraseOp(scOp); clearCurrentOperationGroup(grpIdx); return forOp; } @@ -1698,12 +1698,11 @@ void ForLoopGenerator::clearCurrentOperationGroup(size_t grpIdx) { }; scf::ForOp LoopGeneratorImpl::generateTransposeForLoop(const size_t grpIdx) { - + OpBuilder::InsertionGuard g(*getRewriter()); // transpose rank must bigger than 2 TransposeCanonicalizer &tpCanonicalizer = getTransposeCanonicalizers()[grpIdx]; vector::TransposeOp &tpOp = tpCanonicalizer.getCandidateOps()[0]; - IRRewriter rewriter(tpOp); VectorType vtType = tpOp.getResultVectorType(); size_t rank = vtType.getRank(); @@ -1730,16 +1729,16 @@ scf::ForOp LoopGeneratorImpl::generateTransposeForLoop(const size_t grpIdx) { GenerateLoopHelper loopHelper(grpIdx, 0); prepareForLoopArgs(grpIdx, loopHelper); - OpBuilder b(tpOp); + getRewriter()->setInsertionPoint(tpOp); int tpStep = TransposeCanonicalizer::TRANSPOSE_KERNEL::KERNEL_16X16; // only contains last dim can use fast transpose algorithm if ((tpCanonicalizer.getFirstTpIdx() == (rank - 1) or tpCanonicalizer.getSecondTpIdx() == (rank - 1)) and isTwoDTranspose) { scf::ForOp forOp = generateTransposeForLoopWithLastDim( - b, tpStep, tpOp.getLoc(), successorWriteOp, loopHelper); + *getRewriter(), tpStep, tpOp.getLoc(), successorWriteOp, loopHelper); - rewriter.replaceOp(successorWriteOp, forOp); + getRewriter()->replaceOp(successorWriteOp, forOp); // clear current group operation clearCurrentOperationGroup(grpIdx); return forOp; @@ -1751,10 +1750,10 @@ scf::ForOp LoopGeneratorImpl::generateTransposeForLoop(const size_t grpIdx) { itrIdx++; } // scalar data movement - scf::ForOp forOp = generateTransposeScalarDataMovement(b, tpOp.getLoc(), - tpAxisMap, loopHelper); + scf::ForOp forOp = generateTransposeScalarDataMovement( + *getRewriter(), tpOp.getLoc(), tpAxisMap, loopHelper); - rewriter.replaceOp(successorWriteOp, forOp); + getRewriter()->replaceOp(successorWriteOp, forOp); clearCurrentOperationGroup(grpIdx); return forOp; } @@ -2086,10 +2085,11 @@ void VectorOperationCanonicalizer::run() { // 3. 
Some IR cleanup work DominanceInfo domInfo; eliminateCommonSubExpressions( - rewriter, domInfo, loopGenerator.getVectorBasedFusion().getFunction()); + *getRewriter(), domInfo, + loopGenerator.getVectorBasedFusion().getFunction()); } else { // TODO: need to add directly canonicalize operations logic - // generateGroupOpVectorizedIR(idx, grp, fusionStrategy.opGroupIndexMap); + llvm_unreachable("Currently not support directly canonicalize operations."); } } @@ -2168,6 +2168,7 @@ void ForLoopGenerator::setOperationCorrectOperand( scf::ForOp ForLoopGenerator::constructNestedForOp( const size_t groupIdx, OpBuilder &b, const Location &loc, ArrayRef dims, GenerateLoopHelper &loopHelper) { + OpBuilder::InsertionGuard g(b); const int loop_step = getVectorBasedFusion().getGroupMaxSteps()[groupIdx]; // loop initialization variable auto zero = makeIndexArithConstantOp(b, loc, 0); @@ -2249,11 +2250,12 @@ scf::ForOp ForLoopGenerator::constructNestedForOp( } Value setOutGroupOperationOperandResult(Operation *op, - const VectorType &newOperandType) { + const VectorType &newOperandType, + IRRewriter &rewriter) { + OpBuilder::InsertionGuard g(rewriter); auto ret = TypeSwitch(op) .Case([&](arith::ConstantOp constantOp) { - IRRewriter rewriter(op); rewriter.setInsertionPointAfter(op); Type resultElementType = newOperandType.getElementType(); auto value = constantOp.getValue(); @@ -2293,12 +2295,14 @@ Value setOutGroupOperationOperandResult(Operation *op, } void setOperationOperandResult(Operation *op, const VectorType &newOperandType, - const DenseMap &opMap) { + const DenseMap &opMap, + IRRewriter &rewriter) { + OpBuilder::InsertionGuard g(rewriter); for (auto [idx, x] : llvm::enumerate(op->getOperands())) { if (dyn_cast(x.getType())) { if (!opMap.contains(x.getDefiningOp())) { - auto result = setOutGroupOperationOperandResult(x.getDefiningOp(), - newOperandType); + auto result = setOutGroupOperationOperandResult( + x.getDefiningOp(), newOperandType, rewriter); op->setOperand(idx, result); } else { x.setType(newOperandType); @@ -2313,10 +2317,12 @@ void setOperationOperandResult(Operation *op, const VectorType &newOperandType, void ForLoopGenerator::createNewConstantOp( Operation *srcOp, vector::TransferWriteOp *transferWriteOp, size_t groupSteps) { + IRRewriter &rewriter = *getRewriter(); + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(srcOp); DenseMap &opPermuationMap = getVectorBasedFusion().getOpPermuationMap(); - IRRewriter srcWriter(srcOp); VectorType newOperandType = getVectorBasedFusion().getTypeHelper().getVectorzedType( cast(srcOp), groupSteps); @@ -2326,15 +2332,14 @@ void ForLoopGenerator::createNewConstantOp( auto valueType = dyn_cast(srcConstantOp.getValue()); if (valueType.isSplat()) { FailureOr res = createArithSplatConstantOp( - srcWriter, srcOp->getLoc(), valueType, newOperandType); + rewriter, srcOp->getLoc(), valueType, newOperandType); if (failed(res)) { - llvm::llvm_unreachable_internal("Wrong to create constant op."); + llvm_unreachable("Wrong to create constant op."); } newConstantOp = res.value().getDefiningOp(); } else { // TODO: need to test not splat value - llvm::llvm_unreachable_internal( - "Can't support not splat constant value."); + llvm_unreachable("Can't support not splat constant value."); } newConstantOp->getResult(0).setType(newOperandType); @@ -2342,13 +2347,12 @@ void ForLoopGenerator::createNewConstantOp( opPermuationMap.insert( {*transferWriteOp, transferWriteOp->getPermutationMap()}); setOpVectorizationPermutationMap( - *transferWriteOp, 
srcWriter, + *transferWriteOp, *getRewriter(), cast(transferWriteOp->getResults()[0].getType()), transferWriteOp->getPermutationMap()); return; } - llvm::llvm_unreachable_internal( - "Can't support not DenseElementsAttr constant."); + llvm_unreachable("Can't support not DenseElementsAttr constant."); } /// Rewrite the operations in the group to vectorized form. @@ -2376,8 +2380,6 @@ void ForLoopGenerator::rewriteOperationAsVectorize( TypeSwitch(op) .Case( [&](vector::TransferWriteOp transferWriteOp) { - IRRewriter rewriter(transferWriteOp); - Operation *srcOp = transferWriteOp->getOperand(0).getDefiningOp(); if (isa(srcOp)) { @@ -2411,8 +2413,7 @@ void ForLoopGenerator::rewriteOperationAsVectorize( .Case( [&](vector::MultiDimReductionOp multiReductionOp) { multiReductionOp.dump(); - llvm::llvm_unreachable_internal( - "It should not appear this operation."); + llvm_unreachable("It should not appear this operation."); return failure(); }) .Case([&](Operation *extfOp) { @@ -2422,11 +2423,11 @@ void ForLoopGenerator::rewriteOperationAsVectorize( .Default([&](Operation *op) { if (isSpecialOp(op)) { op->dump(); - llvm::llvm_unreachable_internal( - "It should not appear this operation."); + llvm_unreachable("It should not appear this operation."); return failure(); } - setOperationOperandResult(op, newOperandType, opMap); + setOperationOperandResult(op, newOperandType, opMap, + *getRewriter()); return success(); }); if (failed(lowerResult)) { @@ -2456,16 +2457,7 @@ void GroupOperationFusionImpl::removeOpInCurrentGroups(size_t grpIdx, // erase and replace the operation SmallVector usesOp(op->getUsers().begin(), op->getUsers().end()); - IRRewriter rewriter(op); - rewriter.replaceOp(op, replacedOp); - // update removed operation related operation anchor position - // getGroupOperationFusion().getOpAnchorPos()[replacedOp] = - // getOperationMaxVectorType(replacedOp)->getRank() - 1; - // for (Operation *x : usesOp) - // getGroupOperationFusion().getOpAnchorPos()[x] = - // getOperationMaxVectorType(x)->getRank() - 1; - - // updateOpGroupInfo(grpIdx); + getRewriter()->replaceOp(op, replacedOp); } void GroupOperationFusionImpl::updateOpGroupInfo(size_t grpIdx) { @@ -2534,7 +2526,7 @@ void GroupOperationFusionImpl::generateEmptyTensorAndWrite( size_t sourceOpGid = opGroupIndexMap[sourceOp]; auto [tsr, writeOpresult] = - canonicalizeSourceOperation(sourceOp, visitedOperation); + canonicalizeSourceOperation(sourceOp, visitedOperation, *getRewriter()); auto writeOp = writeOpresult.getDefiningOp(); srcOpCanoniclizedMap.insert({sourceOp, {tsr, writeOpresult}}); updateOpOperandResultInGroups(sourceOpGid, sourceOp, tsr, writeOpresult); @@ -2550,7 +2542,6 @@ void GroupOperationFusionImpl::generateEmptyTensorAndWrite( void GroupOperationFusionImpl::specialOperationRectify( DenseMap &visitedOperation) { auto &opGroups = getGroupOperationFusion().getOpGroups(); - IRRewriter rewriter(getGroupOperationFusion().getFunction()); for (auto [idx, grp] : llvm::enumerate(opGroups)) { std::queue tmpQueue(grp); @@ -2570,7 +2561,7 @@ void GroupOperationFusionImpl::specialOperationRectify( getGroupOperationFusion().getOpAnchorPos()[srcOp] = getGroupOperationFusion().getOpAnchorPos()[op]; - rewriter.replaceOp(op, srcOp); + getRewriter()->replaceOp(op, srcOp); continue; } // anchor of multidim reduction rectify @@ -2613,6 +2604,7 @@ void GroupOperationFusionImpl::updateReturnResultKind(Operation *sourceOp, void GroupOperationFusionImpl::replaceConstantOpAsNewOp(Operation *op, Operation *sourceOp, size_t operandIdx) { + 
OpBuilder::InsertionGuard g(*getRewriter()); DenseMap &opGroupIndexMap = getGroupOperationFusion().getOpGroupIndexMap(); if (!opGroupIndexMap.contains(op)) { @@ -2626,7 +2618,8 @@ void GroupOperationFusionImpl::replaceConstantOpAsNewOp(Operation *op, if (isa(op)) { if (operandIdx == 1) { // accumulate value, just empty tensor is okay - auto resultTensor = getOperationResultTensor(sourceOp, visitedOperation); + auto resultTensor = + getOperationResultTensor(sourceOp, visitedOperation, *getRewriter()); auto opInit = canonicalizeCurrentOperation(op, resultTensor, operandIdx); updateOpOperandResultInGroups(opGroupIndexMap[op], op, opInit); return; @@ -2637,7 +2630,7 @@ void GroupOperationFusionImpl::replaceConstantOpAsNewOp(Operation *op, } auto constantOp = cast(sourceOp); - IRRewriter rewriter(constantOp); + getRewriter()->setInsertionPoint(constantOp); size_t groupSteps = getGroupOperationFusion().getGroupMaxSteps()[opGroupIndexMap[op]]; @@ -2648,9 +2641,9 @@ void GroupOperationFusionImpl::replaceConstantOpAsNewOp(Operation *op, auto valueType = cast(constantOp.getValue()); if (valueType.isSplat()) { FailureOr res = createArithSplatConstantOp( - rewriter, constantOp->getLoc(), valueType, newOperandType); + *getRewriter(), constantOp->getLoc(), valueType, newOperandType); if (failed(res)) - llvm::llvm_unreachable_internal("Wrong to create constant op."); + llvm_unreachable("Wrong to create constant op."); op->setOperand(operandIdx, res.value()); // transfer read operation just use the constant value to do @@ -2660,7 +2653,7 @@ void GroupOperationFusionImpl::replaceConstantOpAsNewOp(Operation *op, op->getOperand(0).getDefiningOp()); return; } - llvm::llvm_unreachable_internal("Can't support not splat constant value."); + llvm_unreachable("Can't support not splat constant value."); } } @@ -2752,13 +2745,14 @@ void GroupOperationFusionImpl::GroupOperationReturnResultProcess( void GroupOperationFusionImpl::broadcastFromElements(Operation *op, size_t grpIdx) { + OpBuilder::InsertionGuard g(*getRewriter()); if (not isa(op)) llvm_unreachable("Must be broadcast operation."); if (not isa(op->getOperandTypes()[0])) { auto inputBcastOp = cast(op); size_t steps = getGroupOperationFusion().getGroupMaxSteps()[grpIdx]; - IRRewriter rewriter(op); + getRewriter()->setInsertionPoint(op); VectorType newOperandType = getGroupOperationFusion().getTypeHelper().getVectorzedType(op, steps); if (isa_and_nonnull(op->getOperand(0).getDefiningOp())) { @@ -2768,7 +2762,7 @@ void GroupOperationFusionImpl::broadcastFromElements(Operation *op, shapes, inputBcastOp.getResultVectorType().getElementType()); FailureOr res = createArithSplatConstantOp( - rewriter, op->getLoc(), + *getRewriter(), op->getLoc(), DenseElementsAttr::get(dataType, constantOp.getValue()), newOperandType); if (failed(res)) @@ -2776,12 +2770,12 @@ void GroupOperationFusionImpl::broadcastFromElements(Operation *op, removeOpInCurrentGroups(grpIdx, op, res.value().getDefiningOp()); } else { - auto bcastOp = rewriter.create( + auto bcastOp = getRewriter()->create( op->getLoc(), newOperandType, op->getOperands()[0]); removeOpInCurrentGroups(grpIdx, op, bcastOp); std::function candidateFunc = isBroadcastOp; moveOpsFrontOrBack(&getGroupOperationFusion().getFunction(), - op->getContext(), candidateFunc); + *getRewriter(), OPPRIORITY::THIRD, OPPRIORITY::THIRD); } } } @@ -2816,7 +2810,6 @@ void GroupOperationFusionImpl::canonicalizeEachOperationGroup() { DenseMap &OpAnchorPos = getGroupOperationFusion().getOpAnchorPos(); func::FuncOp func = 
getGroupOperationFusion().getFunction(); - IRRewriter rewriter(func); analysisGroupMaxSteps(); @@ -2881,6 +2874,7 @@ void ForLoopGenerator::rectifyGroupOperands(size_t currentGroupId, mlir::FailureOr ForLoopGenerator::generateVectorizedForLoop( const size_t groupId, IRRewriter &rewriter, VectorType vectorType) { + OpBuilder::InsertionGuard g(rewriter); // prepare for loop iterargs GenerateLoopHelper loopHelper(groupId, 0); prepareForLoopArgs(groupId, loopHelper); @@ -2906,6 +2900,8 @@ bool LoopGeneratorImpl::isGroupHasSpecialOperation(const size_t grpIdx) { } void LoopGeneratorImpl::generateGroupOpVectorizedIR(const int idx) { + OpBuilder::InsertionGuard g(*getRewriter()); + auto &grp = getVectorBasedFusion().getOpGroups()[idx]; if (grp.empty()) { LDBG("Current operation Group is empty."); @@ -2917,12 +2913,10 @@ void LoopGeneratorImpl::generateGroupOpVectorizedIR(const int idx) { VectorType groupType = getVectorBasedFusion().getGroupBiggestRankVectorType()[idx]; - IRRewriter rewriter(grp.back()); - rewriter.setInsertionPointAfter(grp.back()); // 1. Rewrite operation as vectorized form // 2. Generate loop - // rewriteOperationAsVectorize(rewriter, idx); - auto forOp = generateVectorizedForLoop(idx, rewriter, groupType); + getRewriter()->setInsertionPointAfter(grp.back()); + auto forOp = generateVectorizedForLoop(idx, *getRewriter(), groupType); // special operation do not need to change anything if (failed(forOp)) return; @@ -2939,16 +2933,12 @@ struct CPUPhysicalRegisterPass auto *ctx = &getContext(); RewritePatternSet patterns(ctx); auto func = getOperation(); + IRRewriter rewriter(func); if (hasNotSupportOperation(&func)) { LDBG("Not support operation appears in current function."); return; } - // affineApply operation is always used by other operations. - std::function candidateFunc = isProducerOp; - moveOpsFrontOrBack(&func, ctx, candidateFunc); - candidateFunc = isCandidateMoveOperations; - moveOpsFrontOrBack(&func, ctx, candidateFunc); // canonicalize vector operation, default use vector-based fusion // strategy. HardWareInfo hwInfo; @@ -2956,11 +2946,14 @@ struct CPUPhysicalRegisterPass getAnalysis(); hwInfo.vectorWidth = sysDesc.getMaxVectorWidth(); VectorOperationCanonicalizer canonicalizer( - func, hwInfo, CanonicalizerKind::GroupOperations); + func, hwInfo, &rewriter, CanonicalizerKind::GroupOperations); + + // affineApply operation is always used by other operations. 
+ moveOpsFrontOrBack(&func, rewriter, OPPRIORITY::FIRST, OPPRIORITY::SECOND); + canonicalizer.run(); - candidateFunc = isReadOrWriteOperation; - moveOpsFrontOrBack(&func, ctx, candidateFunc); + moveOpsFrontOrBack(&func, rewriter, OPPRIORITY::LAST, OPPRIORITY::LAST); // transpose kernel vector::VectorTransformsOptions transposeOptions = diff --git a/lib/gc/Transforms/Utils/VectorUtils.cpp b/lib/gc/Transforms/Utils/VectorUtils.cpp index c0897eb42..0f148d41c 100644 --- a/lib/gc/Transforms/Utils/VectorUtils.cpp +++ b/lib/gc/Transforms/Utils/VectorUtils.cpp @@ -17,22 +17,25 @@ namespace gc { #define SAFE_EXPAND(X) X #define LDBG(X) LLVM_DEBUG(DBGS() << SAFE_EXPAND(X) << "\n") -static inline void moveOpBeginingOfBlock(Operation *op) { +static inline void moveOpBeginingOfBlock(Operation *op, IRRewriter &rewriter) { Block *block = op->getBlock(); if (block->getOperations().empty()) llvm_unreachable("Emtpy block."); if (&block->front() == op) return; - op->moveBefore(&block->front()); + rewriter.moveOpAfter(op, op->getBlock(), op->getBlock()->begin()); } -LogicalResult -moveFront(Operation *op, - llvm::DenseMap &operationPosition) { - IRRewriter rewriter(op); - Operation *backOperation; - size_t pos = 0; +// Special behavior for ++OPPRIORITY +OPPRIORITY operator++(OPPRIORITY &c) { + using IntType = typename std::underlying_type::type; + c = static_cast(static_cast(c) + 1); + return c; +} + +LogicalResult moveFront(Operation *op, IRRewriter &rewriter) { + Operation *backOperation = nullptr; // check all the operand is block argument bool allBlockArgs = true; for (auto operand : op->getOperands()) { @@ -42,7 +45,7 @@ moveFront(Operation *op, } } if (allBlockArgs) { - moveOpBeginingOfBlock(op); + moveOpBeginingOfBlock(op, rewriter); return success(); } for (auto operand : op->getOperands()) { @@ -50,85 +53,138 @@ moveFront(Operation *op, continue; Operation *sourceOp = operand.getDefiningOp(); - if (operationPosition[sourceOp] > pos and - sourceOp->getBlock() == op->getBlock()) { + if (sourceOp->getBlock() != op->getBlock()) + continue; + if (not backOperation) { backOperation = sourceOp; - pos = operationPosition[sourceOp]; + continue; } + + if (backOperation->isBeforeInBlock(sourceOp)) + backOperation = sourceOp; } - if (pos == 0) { + if (not backOperation) { // extract operand operation all in previous block - moveOpBeginingOfBlock(op); - return success(); - } - if (backOperation) { - rewriter.moveOpAfter(op, backOperation); + moveOpBeginingOfBlock(op, rewriter); return success(); } - return failure(); + rewriter.moveOpAfter(op, backOperation); + return success(); } -LogicalResult moveBack(Operation *op, - llvm::DenseMap &operationPosition) { - IRRewriter rewriter(op); - Operation *firstOperation; - size_t pos = std::numeric_limits::max(); +LogicalResult moveBack(Operation *op, IRRewriter &rewriter) { + Operation *firstOperation = nullptr; for (auto user : op->getUsers()) { - if (operationPosition[user] < pos and user->getBlock() == op->getBlock()) { + if (user->getBlock() != op->getBlock()) + continue; + if (not firstOperation) { firstOperation = user; - pos = operationPosition[user]; + continue; } + if (user->isBeforeInBlock(firstOperation)) + firstOperation = user; } - if (pos == std::numeric_limits::max()) { + if (not firstOperation) { // Don't move. // TODO: need to consider move before the block which use it. 
return success(); } - if (firstOperation) { - rewriter.moveOpBefore(op, firstOperation); - return success(); - } - return failure(); + rewriter.moveOpBefore(op, firstOperation); + return success(); } void moveCandidateOperation( - llvm::DenseMap &operationPosition, - ArrayRef candidateOps) { - - for (Operation *op : candidateOps) { - auto ret = - TypeSwitch(op) - .Case([&](affine::AffineApplyOp affineOp) { - return moveFront(op, operationPosition); - }) - .Case( - [&](tensor::ExtractSliceOp extractOp) { - return moveFront(op, operationPosition); - }) - .Case([&](tensor::EmptyOp emptyOp) { - return moveFront(op, operationPosition); - }) - .Case([&](tensor::InsertSliceOp insertOp) { - return moveBack(op, operationPosition); - }) - .Case([&](vector::TransferReadOp readOp) { - return moveFront(op, operationPosition); - }) - .Case( - [&](vector::TransferWriteOp writeOp) { - return moveBack(op, operationPosition); - }) - .Case([&](vector::BroadcastOp bcOp) { - return moveFront(op, operationPosition); - }) - .Default([&](Operation *op) { return success(); }); - if (failed(ret)) { - LDBG("Wrong to move operation:" << *op << "\n"); - return; + std::queue> &candidateOps, + IRRewriter &rewriter, OPPRIORITY start, OPPRIORITY end) { + std::queue> remainOps; + OPPRIORITY itrBegin = start; + while (not remainOps.empty() or not candidateOps.empty()) { + while (not candidateOps.empty()) { + std::pair cur = candidateOps.front(); + candidateOps.pop(); + if (cur.second < start or cur.second > end) + continue; + if (cur.second != itrBegin) { + remainOps.push(cur); + continue; + } + + Operation *op = cur.first; + auto ret = + TypeSwitch(op) + .Case([&](affine::AffineApplyOp affineOp) { + return moveFront(op, rewriter); + }) + .Case( + [&](tensor::ExtractSliceOp extractOp) { + return moveFront(op, rewriter); + }) + .Case([&](tensor::EmptyOp emptyOp) { + return moveFront(op, rewriter); + }) + .Case([&](tensor::InsertSliceOp insertOp) { + return moveBack(op, rewriter); + }) + .Case([&](vector::TransferReadOp readOp) { + return moveFront(op, rewriter); + }) + .Case( + [&](vector::TransferWriteOp writeOp) { + return moveBack(op, rewriter); + }) + .Case([&](vector::BroadcastOp bcOp) { + return moveFront(op, rewriter); + }) + .Default([&](Operation *op) { return success(); }); + if (failed(ret)) { + LDBG("Wrong to move operation:" << *op << "\n"); + return; + } } + candidateOps.swap(remainOps); + ++itrBegin; } } +// Get operation priority +void getOperationPriority( + func::FuncOp *func, + std::queue> &candidateOps) { + // get the position of each operation + func->walk([&](Operation *op) { + TypeSwitch(op) + .Case([&](affine::AffineApplyOp affineOp) { + candidateOps.push(std::make_pair(op, OPPRIORITY::FIRST)); + return; + }) + .Case([&](tensor::ExtractSliceOp extractOp) { + candidateOps.push(std::make_pair(op, OPPRIORITY::SECOND)); + return; + }) + .Case([&](tensor::EmptyOp emptyOp) { + candidateOps.push(std::make_pair(op, OPPRIORITY::SECOND)); + return; + }) + .Case([&](tensor::InsertSliceOp insertOp) { + candidateOps.push(std::make_pair(op, OPPRIORITY::SECOND)); + return; + }) + .Case([&](vector::TransferReadOp readOp) { + candidateOps.push(std::make_pair(op, OPPRIORITY::LAST)); + return; + }) + .Case([&](vector::TransferWriteOp writeOp) { + candidateOps.push(std::make_pair(op, OPPRIORITY::LAST)); + return; + }) + .Case([&](vector::BroadcastOp bcOp) { + candidateOps.push(std::make_pair(op, OPPRIORITY::THIRD)); + return; + }) + .Default([&](Operation *op) { return; }); + }); +} + // Need to move some operations like 
extract_slice or insert_slice. // Because those operation may interpret our analysis result. e.g.: // ``` @@ -154,24 +210,14 @@ void moveCandidateOperation( // block. // insert_slice just move them to the privious of the first operation which // use it. -void moveOpsFrontOrBack(func::FuncOp *func, MLIRContext *ctx, - std::function &conditionalFunc) { +void moveOpsFrontOrBack(func::FuncOp *func, IRRewriter &rewriter, + OPPRIORITY start, OPPRIORITY end) { // Pre-order traversal of each op - // Record each operation position. Inorder to we can kown current operation - // should move after which operation. - DenseMap operationPosition; - SmallVector candidateOps; - size_t opCounter = 0; - - // get the position of each operation - func->walk([&](Operation *op) { - operationPosition[op] = opCounter++; - if (conditionalFunc(op)) - candidateOps.emplace_back(op); - }); - moveCandidateOperation(operationPosition, candidateOps); + std::queue> candidateOps; + getOperationPriority(func, candidateOps); + moveCandidateOperation(candidateOps, rewriter, start, end); // eliminate some useless operation - RewritePatternSet patterns(ctx); + RewritePatternSet patterns(rewriter.getContext()); (void)applyPatternsAndFoldGreedily(*func, std::move(patterns)); } diff --git a/test/mlir/test/gc/Transforms/cpu-physical-register.mlir b/test/mlir/test/gc/Transforms/cpu-physical-register.mlir index ae1b18f4b..5a97a3509 100644 --- a/test/mlir/test/gc/Transforms/cpu-physical-register.mlir +++ b/test/mlir/test/gc/Transforms/cpu-physical-register.mlir @@ -26,8 +26,8 @@ // CHECK: scf.for // CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}}: tensor<11008x4096xf32>, vector<16xf32> // CHECK: %[[READ1:.*]] = vector.transfer_read {{.*}}, {{.*}}: tensor<11008x4096xf32>, vector<16xf32> -// CHECK: %[[ADD0:.*]] = arith.addf %[[READ1]], %[[READ0]] : vector<16xf32> -// CHECK: %[[ADD1:.*]] = arith.addf %[[ADD0]], %[[READ0]] : vector<16xf32> +// CHECK: %[[ADD0:.*]] = arith.addf %[[READ0]], %[[READ1]] : vector<16xf32> +// CHECK: %[[ADD1:.*]] = arith.addf %[[ADD0]], %[[READ1]] : vector<16xf32> // CHECK: %[[WRITE:.*]] = vector.transfer_write {{.*}}, {{.*}} : vector<16xf32>, tensor<11008x4096xf32> func.func @add_tensor_test0(%arg0: tensor<11008x4096xf32>, %arg1: tensor<11008x4096xf32>) -> tensor<11008x4096xf32> { %0 = tensor.empty() : tensor<11008x4096xf32> @@ -80,7 +80,7 @@ func.func @reduce_keepdimtest1(%arg0: tensor<16x32x64xf32>) -> tensor<16x1x64xf3 // CHECK: scf.for // CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}}: tensor<512x512xf32>, vector<16xf32> // CHECK: %[[READ1:.*]] = vector.transfer_read {{.*}}, {{.*}}: tensor<512x512xf32>, vector<16xf32> -// CHECK: %[[ADD0:.*]] = arith.addf %[[READ1]], %[[READ0]] : vector<16xf32> +// CHECK: %[[ADD0:.*]] = arith.addf %[[READ0]], %[[READ1]] : vector<16xf32> // CHECK: %[[ADD1:.*]] = arith.maximumf %[[ADD0]], {{.*}} : vector<16xf32> // CHECK: %[[WRITE:.*]] = vector.transfer_write {{.*}}, {{.*}} : vector<16xf32>, tensor<512x512xf32> func.func @fc_relu_test2(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>, @@ -108,7 +108,7 @@ func.func @fc_relu_test2(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>, // CHECK: scf.for // CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<32x32xf32>, vector<16xf32> // CHECK: %[[READ1:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<32x32xf32>, vector<16xf32> -// CHECK: %[[ADD0:.*]] = arith.addf %[[READ1]], %[[READ0]] : vector<16xf32> +// CHECK: %[[ADD0:.*]] = arith.addf 
%[[READ0]], %[[READ1]] : vector<16xf32> // CHECK: %[[WRITE1:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<16xf32>, tensor<32x32xf32> // CHECK: scf.yield // CHECK: scf.yield @@ -243,14 +243,14 @@ func.func @fuse_mlp_test4(%arg0: tensor<128x512xbf16>, %arg1: tensor<32x8x16x32x // CHECK: %[[READ1:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<32x512xbf16>, vector<32xbf16> // CHECK: %[[WRITE1:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<32xbf16>, tensor<1x32x512xbf16> // CHECK-COUNT-4: scf.for -// CHECK: %[[READ2:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<32x16x1x32xbf16>, vector<32xbf16> // CHECK: %[[APPLY1:.*]] = affine.apply +// CHECK: %[[READ2:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<32x16x1x32xbf16>, vector<32xbf16> // CHECK: %[[WRITE2:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<32xbf16>, tensor<1x512x32xbf16> // CHECK: %[[MATMUL0:.*]] = linalg.batch_reduce_matmul // CHECK-COUNT-2: scf.for -// CHECK: %[[READ3:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<32x32xbf16>, vector<32xbf16> -// CHECK: %[[READ4:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<32xbf16>, vector<32xbf16> -// CHECK: %[[ADD0:.*]] = arith.addf %[[READ3]], %[[READ4]] : vector<32xbf16> +// CHECK: %[[READ3:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<32xbf16>, vector<32xbf16> +// CHECK: %[[READ4:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<32x32xbf16>, vector<32xbf16> +// CHECK: %[[ADD0:.*]] = arith.addf %[[READ4]], %[[READ3]] : vector<32xbf16> // CHECK: %[[EXP0:.*]] = math.exp %[[ADD0]] : vector<32xbf16> // CHECK: %[[WRITE3:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<32xbf16>, tensor<32x32xbf16> // CHECK: %[[WRITE4:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<32xbf16>, tensor<32x32xbf16> @@ -490,9 +490,9 @@ func.func @elem_pack_transpose_inner_and_outer_dims2_test8(%arg0: tensor<64xf32> // CHECK: %[[C16:.*]] = arith.constant 16 : index // CHECK: scf.for // CHECK: scf.for -// CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<2x16xf32>, vector<16xf32> -// CHECK: %[[READ1:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<16xf32>, vector<16xf32> -// CHECK: %[[ADD0:.*]] = arith.addf %[[READ1]], %[[READ0]] : vector<16xf32> +// CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<16xf32>, vector<16xf32> +// CHECK: %[[READ1:.*]] = vector.transfer_read {{.*}}, {{.*}} {in_bounds = [true]} : tensor<2x16xf32>, vector<16xf32> +// CHECK: %[[ADD0:.*]] = arith.addf %[[READ0]], %[[READ1]] : vector<16xf32> // CHECK: %[[WRITE0:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true]} : vector<16xf32>, tensor<2x16xf32> func.func @broadcast_same_shape_test9(%input: tensor<16xf32>, %init: tensor<2x16xf32>) -> tensor<2x16xf32> { %empty = tensor.empty() : tensor<2x16xf32> @@ -654,7 +654,7 @@ func.func @reduce_fuse_test13(%input: tensor<16x32x64xf32>, // CHECK: scf.for // CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, {{.*}}: tensor<2xf32>, vector<1xf32> // CHECK: %[[READ1:.*]] = vector.transfer_read {{.*}}, {{.*}}: tensor<2xf32>, vector<1xf32> -// CHECK: %[[ADD0:.*]] = arith.addf %[[READ1]], %[[READ0]] : vector<1xf32> +// CHECK: %[[ADD0:.*]] = arith.addf %[[READ0]], %[[READ1]] : vector<1xf32> // CHECK: 
%[[ADD1:.*]] = arith.maximumf %[[ADD0]], %[[CST]] : vector<1xf32> // CHECK: %[[WRITE:.*]] = vector.transfer_write {{.*}}, {{.*}} : vector<1xf32>, tensor<2xf32> func.func @add_small_tensor_test14(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> { @@ -673,9 +673,9 @@ func.func @add_small_tensor_test14(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) - // CHECK: %[[C16:.*]] = arith.constant 16 : index // CHECK: scf.for %[[arg2:.*]] = %[[C0]] to %[[C64]] step %[[C1]] iter_args(%[[arg3:.*]] = {{.*}}) -> (tensor<64x64xf32>) // CHECK: scf.for %[[arg4:.*]] = %[[C0]] to %[[C64]] step %[[C16]] iter_args(%[[arg5:.*]] = %[[arg3]]) -> (tensor<64x64xf32>) -// CHECK: %[[READ0:.*]] = vector.transfer_read %[[arg5]][%[[arg2]], %[[arg4]]], %[[CST]] {in_bounds = [true]} : tensor<64x64xf32>, vector<16xf32> -// CHECK: %[[READ1:.*]] = vector.transfer_read %arg0[%[[arg4]]], %[[CST]] {in_bounds = [true]} : tensor<64xf32>, vector<16xf32> -// CHECK: %[[ADD0:.*]] = arith.addf %[[READ1]], %[[READ0]] : vector<16xf32> +// CHECK: %[[READ0:.*]] = vector.transfer_read %arg0[%[[arg4]]], %[[CST]] {in_bounds = [true]} : tensor<64xf32>, vector<16xf32> +// CHECK: %[[READ1:.*]] = vector.transfer_read %[[arg5]][%[[arg2]], %[[arg4]]], %[[CST]] {in_bounds = [true]} : tensor<64x64xf32>, vector<16xf32> +// CHECK: %[[ADD0:.*]] = arith.addf %[[READ0]], %[[READ1]] : vector<16xf32> // CHECK: %[[WRITE:.*]] = vector.transfer_write %[[ADD0]], %[[arg5]][%[[arg2]], %[[arg4]]] {in_bounds = [true]} : vector<16xf32>, tensor<64x64xf32> func.func @broadcast_add_test15(%arg0: tensor<64xf32>, %arg1: tensor<64x64xf32>) -> tensor<64x64xf32> { %0 = tensor.empty() : tensor<64x64xf32> From d49715c43d2f171e57db3c96af5c3f961069315d Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Thu, 26 Sep 2024 08:32:48 +0800 Subject: [PATCH 65/66] remove unused function --- lib/gc/Transforms/CPUPhysicalRegisterPass.cpp | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp index 4059e21a3..b8246d17a 100644 --- a/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp +++ b/lib/gc/Transforms/CPUPhysicalRegisterPass.cpp @@ -51,15 +51,6 @@ static inline bool isBroadcastOp(Operation *op) { return isa_and_nonnull(op); } -static inline bool isProducerOp(Operation *op) { - return isa(op); -} - -static inline bool isCandidateMoveOperations(Operation *op) { - return isa( - op); -} - static inline bool isReadOrWriteOperation(Operation *op) { return isa(op); } @@ -100,16 +91,14 @@ bool hasDynamicShape(Operation *op) { // Check operands data type. if (llvm::any_of(op->getOperands(), [&isDynamicShapedType](Value x) { return isDynamicShapedType(x); - })) { + })) return true; - } // Check results data type. 
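// (Any operand or result carrying a dynamic dimension, e.g. a
// tensor<?x32xf32> value, is expected to make hasDynamicShape return true.)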
if (llvm::any_of(op->getResults(), [&isDynamicShapedType](OpResult x) { return isDynamicShapedType(x); - })) { + })) return true; - } return false; } From 57bec503facc13c28ea0c75d163d79a78279f47c Mon Sep 17 00:00:00 2001 From: "Xu, Xiaohui1" Date: Tue, 8 Oct 2024 15:23:19 +0800 Subject: [PATCH 66/66] fix comments --- include/gc/Transforms/Utils/NumericUtils.h | 34 ++++ include/gc/Transforms/Utils/VectorUtils.h | 13 +- lib/gc/Transforms/Utils/CMakeLists.txt | 1 + lib/gc/Transforms/Utils/NumericUtils.cpp | 164 ++++++++++++++++++ lib/gc/Transforms/Utils/VectorUtils.cpp | 191 ++------------------- 5 files changed, 214 insertions(+), 189 deletions(-) create mode 100644 include/gc/Transforms/Utils/NumericUtils.h create mode 100644 lib/gc/Transforms/Utils/NumericUtils.cpp diff --git a/include/gc/Transforms/Utils/NumericUtils.h b/include/gc/Transforms/Utils/NumericUtils.h new file mode 100644 index 000000000..f47d9dace --- /dev/null +++ b/include/gc/Transforms/Utils/NumericUtils.h @@ -0,0 +1,34 @@ +//===-- NumericUtils.h - numeric utilities ----------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef GC_TRANSFORMS_UTILS_NUMERICUTILS_H +#define GC_TRANSFORMS_UTILS_NUMERICUTILS_H +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Types.h" +#include +#include +#include + +namespace mlir { +namespace gc { + +union Float32Bits { + uint32_t u; + float f; +}; +uint16_t float2half(float floatValue); +float half2float(uint16_t halfValue); +uint16_t float2bfloat(float floatValue); +float bfloat2float(uint16_t bfloatBits); +std::variant numeric_limits_minimum(Type type); +std::variant numericLimitsMaximum(Type type); + +} // namespace gc +} // namespace mlir + +#endif \ No newline at end of file diff --git a/include/gc/Transforms/Utils/VectorUtils.h b/include/gc/Transforms/Utils/VectorUtils.h index e70d96155..89341bb82 100644 --- a/include/gc/Transforms/Utils/VectorUtils.h +++ b/include/gc/Transforms/Utils/VectorUtils.h @@ -1,4 +1,4 @@ -//===-- VectorUtils.h - vector fusion analysis ------------------*- C++ -*-===// +//===-- VectorUtils.h - vector utilities ------------------------*- C++ -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
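// A short usage sketch for the NumericUtils.h declarations above; the
// function name `limitsSketch` and the `ctx` parameter are illustrative
// only, and the snippet assumes mlir/IR/BuiltinTypes.h is available for the
// builtin type constructors.
static void limitsSketch(mlir::MLIRContext *ctx) {
  using namespace mlir;
  using namespace mlir::gc;
  // f32 has no finite lower bound here, so the float alternative holds -inf.
  std::variant<float, int64_t> lo =
      numeric_limits_minimum(Float32Type::get(ctx));
  // si8 maxes out at 127, carried in the int64_t alternative.
  std::variant<float, int64_t> hi =
      numericLimitsMaximum(IntegerType::get(ctx, 8, IntegerType::Signed));
  (void)lo;
  (void)hi;
}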
@@ -8,6 +8,7 @@ #ifndef GC_TRANSFORMS_UTILS_VECTORUTILS_H #define GC_TRANSFORMS_UTILS_VECTORUTILS_H +#include "gc/Transforms/Utils/NumericUtils.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -96,16 +97,6 @@ int getNearestVectorStep(const int step); /// prev-op, may need to use result vectortype /// default will return the opeation result type mlir::FailureOr getOperationMaxVectorType(Operation *op); -union Float32Bits { - uint32_t u; - float f; -}; -uint16_t float2half(float floatValue); -float half2float(uint16_t halfValue); -uint16_t float2bfloat(float floatValue); -float bfloat2float(uint16_t bfloatBits); -std::variant numeric_limits_minimum(Type type); -std::variant numericLimitsMaximum(Type type); template T getInitValForReduce(vector::CombiningKind kind, Type t) { diff --git a/lib/gc/Transforms/Utils/CMakeLists.txt b/lib/gc/Transforms/Utils/CMakeLists.txt index 4742a37ac..3b045fca0 100644 --- a/lib/gc/Transforms/Utils/CMakeLists.txt +++ b/lib/gc/Transforms/Utils/CMakeLists.txt @@ -3,6 +3,7 @@ gc_add_mlir_library(GcUtilsIR StructuredOpMatcher.cpp ValueUtils.cpp VectorUtils.cpp + NumericUtils.cpp DEPENDS MLIRLinalgDialect diff --git a/lib/gc/Transforms/Utils/NumericUtils.cpp b/lib/gc/Transforms/Utils/NumericUtils.cpp new file mode 100644 index 000000000..e1af31994 --- /dev/null +++ b/lib/gc/Transforms/Utils/NumericUtils.cpp @@ -0,0 +1,164 @@ +//===- NumericUtils.cpp - numeric utilities ---------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#include "gc/Transforms/Utils/NumericUtils.h" + +namespace mlir { +namespace gc { + +const uint32_t kF32MantiBits = 23; +const uint32_t kF32HalfMantiBitDiff = 13; +const uint32_t kF32HalfBitDiff = 16; +const Float32Bits kF32Magic = {113 << kF32MantiBits}; +const uint32_t kF32HalfExpAdjust = (127 - 15) << kF32MantiBits; +const uint32_t kF32BfMantiBitDiff = 16; + +/// Constructs the 16 bit representation for a half precision value from a float +/// value. This implementation is adapted from Eigen. +uint16_t float2half(float floatValue) { + const Float32Bits inf = {255 << kF32MantiBits}; + const Float32Bits f16max = {(127 + 16) << kF32MantiBits}; + const Float32Bits denormMagic = {((127 - 15) + (kF32MantiBits - 10) + 1) + << kF32MantiBits}; + uint32_t signMask = 0x80000000u; + uint16_t halfValue = static_cast(0x0u); + Float32Bits f; + f.f = floatValue; + uint32_t sign = f.u & signMask; + f.u ^= sign; + + if (f.u >= f16max.u) { + const uint32_t halfQnan = 0x7e00; + const uint32_t halfInf = 0x7c00; + // Inf or NaN (all exponent bits set). + halfValue = (f.u > inf.u) ? halfQnan : halfInf; // NaN->qNaN and Inf->Inf + } else { + // (De)normalized number or zero. + if (f.u < kF32Magic.u) { + // The resulting FP16 is subnormal or zero. + // + // Use a magic value to align our 10 mantissa bits at the bottom of the + // float. As long as FP addition is round-to-nearest-even this works. + f.f += denormMagic.f; + + halfValue = static_cast(f.u - denormMagic.u); + } else { + uint32_t mantOdd = + (f.u >> kF32HalfMantiBitDiff) & 1; // Resulting mantissa is odd. + + // Update exponent, rounding bias part 1. 
The following expressions are + // equivalent to `f.u += ((unsigned int)(15 - 127) << kF32MantiBits) + + // 0xfff`, but without arithmetic overflow. + f.u += 0xc8000fffU; + // Rounding bias part 2. + f.u += mantOdd; + halfValue = static_cast(f.u >> kF32HalfMantiBitDiff); + } + } + + halfValue |= static_cast(sign >> kF32HalfBitDiff); + return halfValue; +} + +/// Converts the 16 bit representation of a half precision value to a float +/// value. This implementation is adapted from Eigen. +float half2float(uint16_t halfValue) { + const uint32_t shiftedExp = + 0x7c00 << kF32HalfMantiBitDiff; // Exponent mask after shift. + + // Initialize the float representation with the exponent/mantissa bits. + Float32Bits f = { + static_cast((halfValue & 0x7fff) << kF32HalfMantiBitDiff)}; + const uint32_t exp = shiftedExp & f.u; + f.u += kF32HalfExpAdjust; // Adjust the exponent + + // Handle exponent special cases. + if (exp == shiftedExp) { + // Inf/NaN + f.u += kF32HalfExpAdjust; + } else if (exp == 0) { + // Zero/Denormal? + f.u += 1 << kF32MantiBits; + f.f -= kF32Magic.f; + } + + f.u |= (halfValue & 0x8000) << kF32HalfBitDiff; // Sign bit. + return f.f; +} + +// Constructs the 16 bit representation for a bfloat value from a float value. +// This implementation is adapted from Eigen. +uint16_t float2bfloat(float floatValue) { + if (std::isnan(floatValue)) + return std::signbit(floatValue) ? 0xFFC0 : 0x7FC0; + + Float32Bits floatBits; + floatBits.f = floatValue; + uint16_t bfloatBits; + + // Least significant bit of resulting bfloat. + uint32_t lsb = (floatBits.u >> kF32BfMantiBitDiff) & 1; + uint32_t roundingBias = 0x7fff + lsb; + floatBits.u += roundingBias; + bfloatBits = static_cast(floatBits.u >> kF32BfMantiBitDiff); + return bfloatBits; +} + +// Converts the 16 bit representation of a bfloat value to a float value. This +// implementation is adapted from Eigen. 
+float bfloat2float(uint16_t bfloatBits) { + Float32Bits floatBits; + floatBits.u = static_cast(bfloatBits) << kF32BfMantiBitDiff; + return floatBits.f; +} + +std::variant numeric_limits_minimum(Type type) { + Type t1 = getElementTypeOrSelf(type); + if (t1.isF32()) { + return -std::numeric_limits::infinity(); + } else if (t1.isBF16()) { + return bfloat2float(float2bfloat(-std::numeric_limits::infinity())); + } else if (t1.isF16()) { + return (float)half2float( + float2half(-std::numeric_limits::infinity())); + } else if (t1.isSignedInteger(8)) { + return int64_t(-128); + } else if (t1.isSignedInteger(32)) { + return int64_t(std::numeric_limits::min()); + } else if (t1.isSignlessInteger(8) or t1.isSignlessInteger(32)) { + return int64_t(0); + } else { + llvm_unreachable("unsupported data type"); + return (int64_t)0; + } +} + +std::variant numericLimitsMaximum(Type type) { + Type t1 = getElementTypeOrSelf(type); + if (t1.isF32()) { + return std::numeric_limits::infinity(); + } else if (t1.isBF16()) { + return bfloat2float(float2bfloat(std::numeric_limits::infinity())); + } else if (t1.isF16()) { + return (float)half2float( + float2half(std::numeric_limits::infinity())); + } else if (t1.isSignedInteger(8)) { + return int64_t(127); + } else if (t1.isSignedInteger(32)) { + return int64_t(std::numeric_limits::max()); + } else if (t1.isSignlessInteger(8)) { + return int64_t(255); + } else if (t1.isSignedInteger(32)) { + return int64_t(std::numeric_limits::max()); + } else { + llvm_unreachable("unsupported data type"); + return (int64_t)0; + } +} + +} // namespace gc +} // namespace mlir \ No newline at end of file diff --git a/lib/gc/Transforms/Utils/VectorUtils.cpp b/lib/gc/Transforms/Utils/VectorUtils.cpp index 0f148d41c..2751b04e2 100644 --- a/lib/gc/Transforms/Utils/VectorUtils.cpp +++ b/lib/gc/Transforms/Utils/VectorUtils.cpp @@ -1,4 +1,4 @@ -//===- VectorUtils.cpp - analysis vector ops --------------------*- C++ -*-===// +//===- VectorUtils.cpp - vector utilities -----------------------*- C++ -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
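// A minimal round-trip sketch for the fp16/bf16 helpers defined above; the
// function name `conversionSketch` is illustrative only, <cassert> is
// assumed to be included, and the expected values follow from the rounding
// code itself.
static void conversionSketch() {
  using namespace mlir::gc;
  // bf16 keeps the top 16 bits of the f32 pattern with round-to-nearest-even,
  // so 1.0f survives exactly while pi rounds to 3.140625f.
  assert(bfloat2float(float2bfloat(1.0f)) == 1.0f);
  assert(bfloat2float(float2bfloat(3.14159265f)) == 3.140625f);
  // 65504 is the largest finite fp16 value, so it also round-trips unchanged.
  assert(half2float(float2half(65504.0f)) == 65504.0f);
}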
@@ -7,7 +7,6 @@ //===----------------------------------------------------------------------===// #include "gc/Transforms/Utils/VectorUtils.h" #include "mlir/Support/LLVM.h" - namespace mlir { namespace gc { @@ -37,13 +36,10 @@ OPPRIORITY operator++(OPPRIORITY &c) { LogicalResult moveFront(Operation *op, IRRewriter &rewriter) { Operation *backOperation = nullptr; // check all the operand is block argument - bool allBlockArgs = true; - for (auto operand : op->getOperands()) { - if (!isa(operand)) { - allBlockArgs = false; - break; - } - } + bool allBlockArgs = llvm::all_of(op->getOperands(), [](Value operand) { + return isa(operand); + }); + if (allBlockArgs) { moveOpBeginingOfBlock(op, rewriter); return success(); @@ -153,31 +149,20 @@ void getOperationPriority( // get the position of each operation func->walk([&](Operation *op) { TypeSwitch(op) - .Case([&](affine::AffineApplyOp affineOp) { + .Case([&](auto op) { candidateOps.push(std::make_pair(op, OPPRIORITY::FIRST)); return; }) - .Case([&](tensor::ExtractSliceOp extractOp) { - candidateOps.push(std::make_pair(op, OPPRIORITY::SECOND)); - return; - }) - .Case([&](tensor::EmptyOp emptyOp) { - candidateOps.push(std::make_pair(op, OPPRIORITY::SECOND)); - return; - }) - .Case([&](tensor::InsertSliceOp insertOp) { - candidateOps.push(std::make_pair(op, OPPRIORITY::SECOND)); - return; - }) - .Case([&](vector::TransferReadOp readOp) { + .Case( + [&](auto op) { + candidateOps.push(std::make_pair(op, OPPRIORITY::SECOND)); + return; + }) + .Case([&](auto op) { candidateOps.push(std::make_pair(op, OPPRIORITY::LAST)); return; }) - .Case([&](vector::TransferWriteOp writeOp) { - candidateOps.push(std::make_pair(op, OPPRIORITY::LAST)); - return; - }) - .Case([&](vector::BroadcastOp bcOp) { + .Case([&](auto op) { candidateOps.push(std::make_pair(op, OPPRIORITY::THIRD)); return; }) @@ -221,156 +206,6 @@ void moveOpsFrontOrBack(func::FuncOp *func, IRRewriter &rewriter, (void)applyPatternsAndFoldGreedily(*func, std::move(patterns)); } -const uint32_t kF32MantiBits = 23; -const uint32_t kF32HalfMantiBitDiff = 13; -const uint32_t kF32HalfBitDiff = 16; -const Float32Bits kF32Magic = {113 << kF32MantiBits}; -const uint32_t kF32HalfExpAdjust = (127 - 15) << kF32MantiBits; -const uint32_t kF32BfMantiBitDiff = 16; - -/// Constructs the 16 bit representation for a half precision value from a float -/// value. This implementation is adapted from Eigen. -uint16_t float2half(float floatValue) { - const Float32Bits inf = {255 << kF32MantiBits}; - const Float32Bits f16max = {(127 + 16) << kF32MantiBits}; - const Float32Bits denormMagic = {((127 - 15) + (kF32MantiBits - 10) + 1) - << kF32MantiBits}; - uint32_t signMask = 0x80000000u; - uint16_t halfValue = static_cast(0x0u); - Float32Bits f; - f.f = floatValue; - uint32_t sign = f.u & signMask; - f.u ^= sign; - - if (f.u >= f16max.u) { - const uint32_t halfQnan = 0x7e00; - const uint32_t halfInf = 0x7c00; - // Inf or NaN (all exponent bits set). - halfValue = (f.u > inf.u) ? halfQnan : halfInf; // NaN->qNaN and Inf->Inf - } else { - // (De)normalized number or zero. - if (f.u < kF32Magic.u) { - // The resulting FP16 is subnormal or zero. - // - // Use a magic value to align our 10 mantissa bits at the bottom of the - // float. As long as FP addition is round-to-nearest-even this works. - f.f += denormMagic.f; - - halfValue = static_cast(f.u - denormMagic.u); - } else { - uint32_t mantOdd = - (f.u >> kF32HalfMantiBitDiff) & 1; // Resulting mantissa is odd. - - // Update exponent, rounding bias part 1. 
The following expressions are - // equivalent to `f.u += ((unsigned int)(15 - 127) << kF32MantiBits) + - // 0xfff`, but without arithmetic overflow. - f.u += 0xc8000fffU; - // Rounding bias part 2. - f.u += mantOdd; - halfValue = static_cast(f.u >> kF32HalfMantiBitDiff); - } - } - - halfValue |= static_cast(sign >> kF32HalfBitDiff); - return halfValue; -} - -/// Converts the 16 bit representation of a half precision value to a float -/// value. This implementation is adapted from Eigen. -float half2float(uint16_t halfValue) { - const uint32_t shiftedExp = - 0x7c00 << kF32HalfMantiBitDiff; // Exponent mask after shift. - - // Initialize the float representation with the exponent/mantissa bits. - Float32Bits f = { - static_cast((halfValue & 0x7fff) << kF32HalfMantiBitDiff)}; - const uint32_t exp = shiftedExp & f.u; - f.u += kF32HalfExpAdjust; // Adjust the exponent - - // Handle exponent special cases. - if (exp == shiftedExp) { - // Inf/NaN - f.u += kF32HalfExpAdjust; - } else if (exp == 0) { - // Zero/Denormal? - f.u += 1 << kF32MantiBits; - f.f -= kF32Magic.f; - } - - f.u |= (halfValue & 0x8000) << kF32HalfBitDiff; // Sign bit. - return f.f; -} - -// Constructs the 16 bit representation for a bfloat value from a float value. -// This implementation is adapted from Eigen. -uint16_t float2bfloat(float floatValue) { - if (std::isnan(floatValue)) - return std::signbit(floatValue) ? 0xFFC0 : 0x7FC0; - - Float32Bits floatBits; - floatBits.f = floatValue; - uint16_t bfloatBits; - - // Least significant bit of resulting bfloat. - uint32_t lsb = (floatBits.u >> kF32BfMantiBitDiff) & 1; - uint32_t roundingBias = 0x7fff + lsb; - floatBits.u += roundingBias; - bfloatBits = static_cast(floatBits.u >> kF32BfMantiBitDiff); - return bfloatBits; -} - -// Converts the 16 bit representation of a bfloat value to a float value. This -// implementation is adapted from Eigen. -float bfloat2float(uint16_t bfloatBits) { - Float32Bits floatBits; - floatBits.u = static_cast(bfloatBits) << kF32BfMantiBitDiff; - return floatBits.f; -} - -std::variant numeric_limits_minimum(Type type) { - Type t1 = getElementTypeOrSelf(type); - if (t1.isF32()) { - return -std::numeric_limits::infinity(); - } else if (t1.isBF16()) { - return bfloat2float(float2bfloat(-std::numeric_limits::infinity())); - } else if (t1.isF16()) { - return (float)half2float( - float2half(-std::numeric_limits::infinity())); - } else if (t1.isSignedInteger(8)) { - return int64_t(-128); - } else if (t1.isSignedInteger(32)) { - return int64_t(std::numeric_limits::min()); - } else if (t1.isSignlessInteger(8) or t1.isSignlessInteger(32)) { - return int64_t(0); - } else { - llvm_unreachable("unsupported data type"); - return (int64_t)0; - } -} - -std::variant numericLimitsMaximum(Type type) { - Type t1 = getElementTypeOrSelf(type); - if (t1.isF32()) { - return std::numeric_limits::infinity(); - } else if (t1.isBF16()) { - return bfloat2float(float2bfloat(std::numeric_limits::infinity())); - } else if (t1.isF16()) { - return (float)half2float( - float2half(std::numeric_limits::infinity())); - } else if (t1.isSignedInteger(8)) { - return int64_t(127); - } else if (t1.isSignedInteger(32)) { - return int64_t(std::numeric_limits::max()); - } else if (t1.isSignlessInteger(8)) { - return int64_t(255); - } else if (t1.isSignedInteger(32)) { - return int64_t(std::numeric_limits::max()); - } else { - llvm_unreachable("unsupported data type"); - return (int64_t)0; - } -} - mlir::FailureOr getOperationVectorType(Operation *op, bool isPrevOp) { if (not op)