Enable 2D block IO for tensor of pointers when the tensor in memory is contiguous #3482

Open
wants to merge 1 commit into base: main

Conversation

@chengjunlu (Contributor) commented Feb 21, 2025

General

This is an alternative way to lower tt.load with a tensor of pointers to 2D block IO.
It depends on the analysis results from ModuleAxisInfoAnalysis for the pointers and masks.

Background

Both the block pointer type (e.g. !tt.ptr<tensor<64x32xf16>>) and the tensor-of-pointers type (e.g. tensor<64x32x!tt.ptr<f16>>) are used to describe a tensor resident in global memory.
The difference is that a tensor of pointers carries more "entropy" than a block pointer: it can describe a tensor whose elements are scattered arbitrarily in global memory, and of course it can also describe the same structured distribution that a block pointer does.

There is already an optimization pass that raises a tensor of pointers to a block pointer, but it handles only some cases and has limitations.
The approach in this PR supports more cases, with fewer limitations, for lowering tt.load with a tensor of pointers to 2D block IO.
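
To make the distinction concrete, below is a minimal C++ sketch of how a lowering pattern can tell the two forms apart. It assumes the standard MLIR/Triton type utilities (triton::isTensorPointerType, triton::PointerType) and is only an illustration, not code from this PR.

#include "mlir/IR/BuiltinTypes.h"
#include "triton/Dialect/Triton/IR/Types.h"

// Block pointer, e.g. !tt.ptr<tensor<64x32xf16>>: the shape and strides live
// in the pointer itself, so the memory layout is structured by construction.
static bool isBlockPointer(mlir::Type ptrTy) {
  return mlir::triton::isTensorPointerType(ptrTy);
}

// Tensor of pointers, e.g. tensor<64x32x!tt.ptr<f16>>: one address per
// element, so the elements may be scattered arbitrarily in global memory.
static bool isTensorOfPointers(mlir::Type ptrTy) {
  auto tensorTy = mlir::dyn_cast<mlir::RankedTensorType>(ptrTy);
  return tensorTy &&
         mlir::isa<mlir::triton::PointerType>(tensorTy.getElementType());
}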

Idea

ModuleAxisInfoAnalysis computes per-axis information, such as contiguity, divisibility, and constancy, for the pointer and mask tensor values.
In general, we can use the 2D block IO lowering as long as:

  1. The contiguity of the pointers is a multiple of the threadsPerWarp size in the same dimension D.
  2. The constancy of the masks is a multiple of the threadsPerWarp size in the same dimension D.

We start with the case of tt.load with the DotOp and DPAS layouts, and will generalize the lowering code to more cases with LinearLayout in the future.
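
As a rough illustration of the two conditions above, the hedged C++ sketch below shows how the contiguity/constancy results from ModuleAxisInfoAnalysis could gate the 2D block IO path. The helper name canUse2DBlockIO and its parameters are illustrative placeholders, not the exact code in this PR.

#include "mlir/IR/Value.h"
#include "triton/Analysis/AxisInfo.h"

// Sketch: decide whether a tt.load with a tensor of pointers may use the 2D
// block IO lowering, based on the AxisInfo along the contiguous dimension D.
static bool canUse2DBlockIO(mlir::triton::ModuleAxisInfoAnalysis &axisAnalysis,
                            mlir::Value ptr, mlir::Value mask, unsigned dim,
                            unsigned threadsPerWarp) {
  mlir::triton::AxisInfo *ptrInfo = axisAnalysis.getAxisInfo(ptr);
  if (!ptrInfo)
    return false;
  // 1. Pointers: the contiguity along dim D must be a multiple of
  //    threadsPerWarp, so each warp reads one contiguous segment of memory.
  if (ptrInfo->getContiguity(dim) % threadsPerWarp != 0)
    return false;
  // 2. Masks (if present): the constancy along dim D must be a multiple of
  //    threadsPerWarp, so the whole warp sees a uniform predicate.
  if (mask) {
    mlir::triton::AxisInfo *maskInfo = axisAnalysis.getAxisInfo(mask);
    if (!maskInfo || maskInfo->getConstancy(dim) % threadsPerWarp != 0)
      return false;
  }
  return true;
}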

@chengjunlu (Contributor Author):

I will add more comprehensive LIT test cases.

@chengjunlu chengjunlu marked this pull request as draft February 21, 2025 07:14
@chengjunlu chengjunlu force-pushed the chengjun/tensorptr_blockio branch from 6e6f4c7 to 787dd41 on February 24, 2025 12:48
@chengjunlu chengjunlu marked this pull request as ready for review February 24, 2025 12:49
const bool memoryRowMajor = (memoryLayoutInfo == "row_major");

auto getOpIdx = [&]() -> DpasEncodingAttr::OpIdx {
if (hasDpasLayout) {
Contributor:
Since you have dpasLayout above, you can probably remove this branch. Then again, will we eventually be able to sync this function with the code in the rewrite tensor pointer load pass?

Contributor Author:
I will remove the DPAS branch first, for simplicity.

The idea is that we can finally unify the tt.load lowering to block IO in one pattern class. The plan is to do it in several steps:

  1. Unify the tt.load lowering to block IO for both tensor of pointers and block pointer.
  2. Support DPAS and DotOp with the DPAS layout.
  3. Support general cases with LinearLayout: different tensor shapes and different layouts of the returned value.

@LiyangLingIntel (Contributor):

Triggered https://github.com/intel/intel-xpu-backend-for-triton/actions/runs/13561125937 to check the performance improvement.

@etiotto (Contributor) commented Mar 3, 2025

Have you tested what happens to a workload that uses masked loads (that are loop variant) and the mask switches from true to false in the middle of the loop iterations (or vice versa)?

Note: I have already enabled (by default) a pass that versions the loop for commonly masked loads. Given that it is in Triton, do we need these changes?

@chengjunlu (Contributor Author) commented Mar 3, 2025

Have you tested what happens to a workload that uses masked loads (that are loop variant)?

I used the matmul kernel in tutorial 09 to test the masked loads, which were not supported by the raise-block-pointer pass at that time, so we could only use the tensor of pointers. The first version reaches only about 30%~40% of the block pointer matmul's performance in the benchmark.

The masked loads with other values are lowered to if-else branches, which cannot be optimized by IGC.

It works functionally, but more effort is needed to improve the performance, such as integrating the versioning pass and some other optimizations.

@etiotto (Contributor) commented Mar 3, 2025

Have you tested what happens to a workload that uses masked loads (that are loop variant)?

I used the matmul kernel in tutorial 09 to test the masked loads, which were not supported by the raise-block-pointer pass at that time, so we could only use the tensor of pointers. The first version reaches only about 30%~40% of the block pointer matmul's performance in the benchmark.

The masked loads with other values are lowered to if-else branches, which cannot be optimized by IGC.

It works functionally, but more effort is needed to improve the performance, such as integrating the versioning pass and some other optimizations.

OK, this is promising. The loop versioning pass that landed last week (#3516) should version the loop, and the versioned loop will then contain unmasked tt.load operations, so IGC will not need to deal with branches in the loop. BTW, I targeted the loop versioning pass at tutorial 03, so at the moment I am not sure whether it will version the loop in tutorial 09 (but the pass can be extended).

return success();
}

private:
Contributor:
remove


LogicalResult
matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Contributor:
override -> final

Comment on lines +515 to +518
auto encoding = tensorType.getEncoding();
const bool hasDpasLayout = isa<DpasEncodingAttr>(encoding);
if (!hasDpasLayout && !hasDotDpasEncoding(tensorType))
return failure();
Contributor:
Suggested change
auto encoding = tensorType.getEncoding();
const bool hasDpasLayout = isa<DpasEncodingAttr>(encoding);
if (!hasDpasLayout && !hasDotDpasEncoding(tensorType))
return failure();
const bool hasDpasLayout = hasDpasEncoding(tensorType);
if (!hasDpasLayout && !hasDotDpasEncoding(tensorType))
return failure();
auto encoding = cast<DpasEncodingAttr>(tensorType.getEncoding());

Comment on lines +528 to +535
auto getOpIdx = [&]() -> DpasEncodingAttr::OpIdx {
if (hasDpasLayout) {
return DpasEncodingAttr::OpIdx::OperandC;
} else {
auto dotLayout = getDotEncoding(tensorType).value();
return static_cast<DpasEncodingAttr::OpIdx>(dotLayout.getOpIdx());
}
};
Contributor:
Suggested change
auto getOpIdx = [&]() -> DpasEncodingAttr::OpIdx {
if (hasDpasLayout) {
return DpasEncodingAttr::OpIdx::OperandC;
} else {
auto dotLayout = getDotEncoding(tensorType).value();
return static_cast<DpasEncodingAttr::OpIdx>(dotLayout.getOpIdx());
}
};
auto getOpIdx = [&]() -> DpasEncodingAttr::OpIdx {
if (hasDpasLayout)
return DpasEncodingAttr::OpIdx::OperandC;
assert(hasDotDpasEncoding(tensorType) && "Expecting dot layout");
auto dotLayout = getDotEncoding(tensorType).value();
return static_cast<DpasEncodingAttr::OpIdx>(dotLayout.getOpIdx());
};

return static_cast<DpasEncodingAttr::OpIdx>(dotLayout.getOpIdx());
}
};
auto opIdx = getOpIdx();
Contributor:
auto -> DpasEncodingAttr::OpIdx

SmallVector<Value> ptrElems, maskElems, otherElems;
// Get the LLVM values for pointers
ptrElems = unpackLLElements(loc, llPtr, rewriter);
assert(ptrElems.size() == numElems);
Contributor:
Add assert message

Comment on lines +819 to +824
auto offsetOuter =
outer * repOuterStride +
rep * dpasInstShape[dimOuter] * numOperandsOuterDimPerLoad;
auto offsetInner = inner * dpasInstShape[dimInner];
auto offsetM = (isOperandA ? offsetOuter : offsetInner);
auto offsetN = (isOperandA ? offsetInner : offsetOuter);
Contributor:
Suggested change
auto offsetOuter =
outer * repOuterStride +
rep * dpasInstShape[dimOuter] * numOperandsOuterDimPerLoad;
auto offsetInner = inner * dpasInstShape[dimInner];
auto offsetM = (isOperandA ? offsetOuter : offsetInner);
auto offsetN = (isOperandA ? offsetInner : offsetOuter);
unsigned offsetOuter =
outer * repOuterStride +
rep * dpasInstShape[dimOuter] * numOperandsOuterDimPerLoad;
unsigned offsetInner = inner * dpasInstShape[dimInner];
unsigned offsetM = (isOperandA ? offsetOuter : offsetInner);
unsigned offsetN = (isOperandA ? offsetInner : offsetOuter);

pred = targetInfo.shuffleIdx(rewriter, loc, pred, 0);
Value other_ = b.undef(load2DGenXType);
if (others.size()) {
auto vecTy = vec_ty(eltTy, numValuesPerLoad * packedElemsNum);
Contributor:
Suggested change
auto vecTy = vec_ty(eltTy, numValuesPerLoad * packedElemsNum);
VectorType vecTy = vec_ty(eltTy, numValuesPerLoad * packedElemsNum);

Comment on lines +843 to +846
auto N = packedCol +
col * threadsPerWarp * numColPerPackedValue +
vblk * tileWidth + offsetN;
auto M = i + offsetM;
Contributor:
Suggested change
auto N = packedCol +
col * threadsPerWarp * numColPerPackedValue +
vblk * tileWidth + offsetN;
auto M = i + offsetM;
unsigned N = packedCol +
col * threadsPerWarp * numColPerPackedValue +
vblk * tileWidth + offsetN;
unsigned M = i + offsetM;

// RUN: triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm | FileCheck %s --implicit-check-not=llvm.inline_asm

// CHECK: llvm.func spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x2cPU3AS1viiiDv2_iPt
#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
Contributor:
Rather long tests, can they be reduced and simplified please?

@etiotto (Contributor) left a comment:
I left several inline comments that should be addressed prior to merging this PR. The approach LGTM.

Successfully merging this pull request may close these issues.

[Performance] Enable 2D Block IO lowering for tt.load with tensor of pointer