Enable 2D block IO for tensors of pointers when the tensor in memory is contiguous. #3482
base: main
Conversation
I will add more comprehensive LIT test cases.
Force-pushed from 6e6f4c7 to 787dd41 ("…y are contiguous.")
const bool memoryRowMajor = (memoryLayoutInfo == "row_major");

auto getOpIdx = [&]() -> DpasEncodingAttr::OpIdx {
  if (hasDpasLayout) {
Since you have dpasLayout above, you can probably remove this branch. Then again, will we eventually be able to sync this function with the code in the rewrite-tensor-pointer load?
I will remove the DPAS branch first, for simplicity.
The idea is that we can eventually unify the `tt.load` lowering to block IO in one pattern class. The plan is to do it in several steps (a rough sketch of the unified pattern follows the list):
- Unify the `tt.load` lowering to block IO for both tensors of pointers and block pointers.
- Support both the DPAS layout and the DotOp layout with a DPAS parent.
- Support general cases through LinearLayout: different tensor shapes and different layouts of the returned value.
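A minimal sketch of what such a unified pattern class could look like; the class name, step comments, and base class choice are hypothetical, and only the single-entry-point shape is implied by the plan above:

```cpp
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "triton/Dialect/Triton/IR/Dialect.h"

// Hypothetical skeleton (names illustrative): a single conversion pattern
// that would subsume both the block-pointer and the tensor-of-pointers
// lowering of tt.load to 2D block IO.
struct LoadOpToBlockIOConversion
    : public mlir::ConvertOpToLLVMPattern<mlir::triton::LoadOp> {
  using mlir::ConvertOpToLLVMPattern<
      mlir::triton::LoadOp>::ConvertOpToLLVMPattern;

  mlir::LogicalResult
  matchAndRewrite(mlir::triton::LoadOp op, OpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const final {
    // Step 1: accept both block pointers and tensors of pointers.
    // Step 2: handle the DPAS layout and DotOp layouts with a DPAS parent.
    // Step 3: generalize shapes and result layouts via LinearLayout.
    return mlir::failure(); // placeholder until the steps land
  }
};
```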
Triggered https://github.com/intel/intel-xpu-backend-for-triton/actions/runs/13561125937 to check the performance improvement.
Have you tested what happens to a workload that uses masked loads (that are loop-variant) where the mask switches from true to false in the middle of the loop iterations (or vice versa)? Note: I have already enabled (by default) a pass to version the loop for commonly masked loads. Given that is in Triton, do we need these changes?
I used the matmul kernel in tutorial 09 to test the masked loads, which were not supported by the raise-block-pointer pass at that time, so we could only use the tensor of pointers. The first version reached only about 30%~40% of the performance of the block-pointer matmul in the benchmark. Masked loads with `other` values are lowered to if-else branches which cannot be optimized by IGC. It works functionally, but needs more effort to improve performance, such as integrating the loop-versioning pass and other optimizations.
OK, this is promising. The loop-versioning pass that landed last week (#3516) should version the loop, and the versioned loop will then contain unmasked loads.
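For context, loop versioning conceptually splits a masked loop into an unmasked main loop plus a masked remainder. A scalar C++ analogy (illustrative only; the actual pass rewrites Triton IR, and all names here are made up):

```cpp
#include <algorithm>
#include <cstddef>

// Scalar analogy of loop versioning. The original loop guards every
// access with a bound check (the "mask"); the versioned form runs an
// unmasked main loop over iterations where the mask is provably
// all-true, plus a masked remainder.
void copyMasked(float *dst, const float *src, std::size_t n,
                std::size_t bound) {
  // Original, masked on every iteration:
  //   for (std::size_t i = 0; i < n; ++i)
  //     dst[i] = (i < bound) ? src[i] : 0.0f;
  std::size_t safe = std::min(n, bound);
  for (std::size_t i = 0; i < safe; ++i) // unmasked main loop
    dst[i] = src[i];
  for (std::size_t i = safe; i < n; ++i) // masked remainder ("other" value)
    dst[i] = 0.0f;
}
```

The unmasked main loop is the part that can then take the 2D block IO path without per-element if-else branches.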
    return success();
  }

private:
remove
LogicalResult
matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
override -> final
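That is, the hunk above would become (fragment mirroring the hunk; only the qualifier changes):

```cpp
LogicalResult
matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const final {
```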
auto encoding = tensorType.getEncoding();
const bool hasDpasLayout = isa<DpasEncodingAttr>(encoding);
if (!hasDpasLayout && !hasDotDpasEncoding(tensorType))
  return failure();
Suggested change:

const bool hasDpasLayout = hasDpasEncoding(tensorType);
if (!hasDpasLayout && !hasDotDpasEncoding(tensorType))
  return failure();
auto encoding = cast<DpasEncodingAttr>(tensorType.getEncoding());
auto getOpIdx = [&]() -> DpasEncodingAttr::OpIdx {
  if (hasDpasLayout) {
    return DpasEncodingAttr::OpIdx::OperandC;
  } else {
    auto dotLayout = getDotEncoding(tensorType).value();
    return static_cast<DpasEncodingAttr::OpIdx>(dotLayout.getOpIdx());
  }
};
Suggested change:

auto getOpIdx = [&]() -> DpasEncodingAttr::OpIdx {
  if (hasDpasLayout)
    return DpasEncodingAttr::OpIdx::OperandC;
  assert(hasDotDpasEncoding(tensorType) && "Expecting dot layout");
  auto dotLayout = getDotEncoding(tensorType).value();
  return static_cast<DpasEncodingAttr::OpIdx>(dotLayout.getOpIdx());
};
    return static_cast<DpasEncodingAttr::OpIdx>(dotLayout.getOpIdx());
  }
};
auto opIdx = getOpIdx();
auto -> DpasEncodingAttr::OpIdx
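Applied to the hunk above, that is:

```cpp
DpasEncodingAttr::OpIdx opIdx = getOpIdx();
```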
SmallVector<Value> ptrElems, maskElems, otherElems;
// Get the LLVM values for pointers
ptrElems = unpackLLElements(loc, llPtr, rewriter);
assert(ptrElems.size() == numElems);
Add assert message
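For example (the message wording is illustrative):

```cpp
assert(ptrElems.size() == numElems &&
       "the number of pointer elements should match the tensor size");
```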
auto offsetOuter =
    outer * repOuterStride +
    rep * dpasInstShape[dimOuter] * numOperandsOuterDimPerLoad;
auto offsetInner = inner * dpasInstShape[dimInner];
auto offsetM = (isOperandA ? offsetOuter : offsetInner);
auto offsetN = (isOperandA ? offsetInner : offsetOuter);
Suggested change:

unsigned offsetOuter =
    outer * repOuterStride +
    rep * dpasInstShape[dimOuter] * numOperandsOuterDimPerLoad;
unsigned offsetInner = inner * dpasInstShape[dimInner];
unsigned offsetM = (isOperandA ? offsetOuter : offsetInner);
unsigned offsetN = (isOperandA ? offsetInner : offsetOuter);
pred = targetInfo.shuffleIdx(rewriter, loc, pred, 0);
Value other_ = b.undef(load2DGenXType);
if (others.size()) {
  auto vecTy = vec_ty(eltTy, numValuesPerLoad * packedElemsNum);
Suggested change:

VectorType vecTy = vec_ty(eltTy, numValuesPerLoad * packedElemsNum);
auto N = packedCol +
         col * threadsPerWarp * numColPerPackedValue +
         vblk * tileWidth + offsetN;
auto M = i + offsetM;
Suggested change:

unsigned N = packedCol +
             col * threadsPerWarp * numColPerPackedValue +
             vblk * tileWidth + offsetN;
unsigned M = i + offsetM;
// RUN: triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm | FileCheck %s --implicit-check-not=llvm.inline_asm

// CHECK: llvm.func spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x2cPU3AS1viiiDv2_iPt
#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
These tests are rather long; can they be reduced and simplified, please?
I left several inline comments that should be addressed prior to merging this PR. The approach LGTM.
General
This is an alternative way to lower `tt.load` with a tensor of pointers to 2D block IO. It depends on the analysis results from `ModuleAxisInfoAnalysis` about the pointers and masks.
Background
Both the block pointer type (e.g. `!tt.ptr<tensor<64x32xf16>>`) and the tensor-of-pointers type (e.g. `tensor<64x32x!tt.ptr<f16>>`) are used to describe a tensor resident in global memory. The difference is that the tensor of pointers carries more "entropy" than the block pointer: a tensor of pointers can describe a tensor on global memory that is randomly distributed, though of course it can also describe the same structured distribution as a block pointer.
There is already an optimization pass that raises the tensor of pointers to a block pointer for some cases, with limitations. The approach in this PR supports more cases, with fewer limitations, for lowering `tt.load` with a tensor of pointers to 2D block IO.
Idea
`ModuleAxisInfoAnalysis` analyzes axis information, such as contiguity, divisibility, and constancy of the values of the pointer and mask tensors. In general, we can use the 2D block IO lowering as long as the axis info shows that the pointers are contiguous along one dimension in memory, with divisibility satisfying the alignment requirements and sufficient constancy of the masks.
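A sketch of how such a gate could be computed, assuming the upstream Triton `ModuleAxisInfoAnalysis`/`AxisInfo` accessors (`getAxisInfo`, `getRank`, `getContiguity`, `getDivisibility`); the helper name, the `elemsPerRow` parameter, and the alignment constant are illustrative, not the PR's actual code:

```cpp
#include "triton/Analysis/AxisInfo.h"

using namespace mlir;
using namespace mlir::triton;

// Illustrative gate: decide whether a tensor-of-pointers load looks
// contiguous enough, row-major, for one 2D block read per row.
static bool canUse2DBlockIO(ModuleAxisInfoAnalysis &axisAnalysis, Value ptr,
                            unsigned elemsPerRow) {
  AxisInfo *info = axisAnalysis.getAxisInfo(ptr);
  if (!info)
    return false;
  unsigned rank = info->getRank();
  // Contiguity of the innermost dimension: consecutive pointers address
  // consecutive memory, so a whole row fits one block-read message.
  bool rowContiguous = info->getContiguity(rank - 1) >= elemsPerRow;
  // Divisibility approximates alignment; 2D block IO has base/pitch
  // alignment requirements (the 16 here is illustrative).
  bool aligned = info->getDivisibility(rank - 1) % 16 == 0;
  return rowContiguous && aligned;
}
```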
We start with the case of `tt.load` with the DotOp and DPAS layouts, and plan to generalize the lowering code to more cases with LinearLayout in the future.