-
Notifications
You must be signed in to change notification settings - Fork 14.5k
AMDGPU: Support intrinsic selection for gfx1250 wmma instructions #148957
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Co-Authored-by: Stanislav Mekhanoshin <[email protected]> Co-Authored-by: Shilei Tian <Shilei.Tian.com>
@llvm/pr-subscribers-backend-amdgpu @llvm/pr-subscribers-llvm-analysis Author: Changpeng Fang (changpeng) ChangesPatch is 375.19 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/148957.diff 20 Files Affected:
diff --git a/llvm/include/llvm/IR/IntrinsicsAMDGPU.td b/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
index 962693003349e..acd10af1709ea 100644
--- a/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
+++ b/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
@@ -2919,6 +2919,20 @@ def int_amdgcn_permlanex16_var : ClangBuiltin<"__builtin_amdgcn_permlanex16_var"
// the form: D = A * B + C.
// A is sparse matrix, half the size of B, and is expanded using sparsity index.
+class AMDGPUSWmmacIntrinsicIdxReuse<LLVMType A, LLVMType B, LLVMType CD, LLVMType Index> :
+ Intrinsic<
+ [CD], // %D
+ [
+ A, // %A
+ B, // %B
+ LLVMMatchType<0>, // %C
+ Index, // %Sparsity index for A
+ llvm_i1_ty, // matrix_a_reuse
+ llvm_i1_ty, // matrix_b_reuse
+ ],
+ [IntrNoMem, IntrConvergent, IntrWillReturn, ImmArg<ArgIndex<4>>, ImmArg<ArgIndex<5>>]
+>;
+
class AMDGPUSWmmacIntrinsicIdx<LLVMType A, LLVMType B, LLVMType CD, LLVMType Index> :
Intrinsic<
[CD], // %D
@@ -3602,6 +3616,161 @@ def int_amdgcn_fdiv_fast : DefaultAttrsIntrinsic<
[IntrNoMem, IntrSpeculatable]
>;
+// WMMA intrinsics.
+class AMDGPUWmmaIntrinsicModsAB<LLVMType AB, LLVMType CD> :
+ Intrinsic<
+ [CD], // %D
+ [
+ llvm_i1_ty, // %A_mod: 0 -- none, 1 -- neg
+ AB, // %A
+ llvm_i1_ty, // %B_mod: 0 -- none, 1 -- neg
+ LLVMMatchType<1>, // %B
+ LLVMMatchType<0>, // %C
+ llvm_i1_ty, // matrix_a_reuse
+ llvm_i1_ty, // matrix_b_reuse
+ ],
+ [IntrNoMem, IntrConvergent, ImmArg<ArgIndex<0>>, ImmArg<ArgIndex<2>>, ImmArg<ArgIndex<5>>, ImmArg<ArgIndex<6>>,
+ IntrWillReturn, IntrNoCallback, IntrNoFree]
+>;
+
+class AMDGPUWmmaIntrinsicModsC<LLVMType AB, LLVMType CD> :
+ Intrinsic<
+ [CD], // %D
+ [
+ AB, // %A
+ LLVMMatchType<1>, // %B
+ llvm_i16_ty, // %C_mod: 0 - none, 1 - neg, 2 - abs, 3 - neg(abs)
+ LLVMMatchType<0>, // %C
+ llvm_i1_ty, // matrix_a_reuse
+ llvm_i1_ty, // matrix_b_reuse
+ ],
+ [IntrNoMem, IntrConvergent, ImmArg<ArgIndex<2>>, ImmArg<ArgIndex<4>>, ImmArg<ArgIndex<5>>,
+ IntrWillReturn, IntrNoCallback, IntrNoFree]
+>;
+
+class AMDGPUWmmaIntrinsicF4ModsC<LLVMType A, LLVMType B, LLVMType CD> :
+ Intrinsic<
+ [CD], // %D
+ [
+ A, // %A
+ B, // %B
+ llvm_i16_ty, // %C_mod: 0 - none, 1 - neg, 2 - abs, 3 - neg(abs)
+ LLVMMatchType<0>, // %C
+ ],
+ [IntrNoMem, IntrConvergent, ImmArg<ArgIndex<2>>, IntrWillReturn, IntrNoCallback, IntrNoFree]
+>;
+
+class AMDGPUWmmaIntrinsicModsAll<LLVMType AB, LLVMType CD> :
+ Intrinsic<
+ [CD], // %D
+ [
+ llvm_i1_ty, // %A_mod: 0 -- none, 1 -- neg
+ AB, // %A
+ llvm_i1_ty, // %B_mod: 0 -- none, 1 -- neg
+ LLVMMatchType<1>, // %B
+ llvm_i16_ty, // %C_mod: 0 -- none, 1 -- neg, 2 -- abs, 3 -- neg(abs)
+ LLVMMatchType<0>, // %C
+ ],
+ [IntrNoMem, IntrConvergent, ImmArg<ArgIndex<0>>, ImmArg<ArgIndex<2>>, ImmArg<ArgIndex<4>>, IntrWillReturn, IntrNoCallback, IntrNoFree]
+>;
+
+class AMDGPUWmmaIntrinsicModsAllReuse<LLVMType AB, LLVMType CD> :
+ Intrinsic<
+ [CD], // %D
+ [
+ llvm_i1_ty, // %A_mod: 0 -- none, 1 -- neg
+ AB, // %A
+ llvm_i1_ty, // %B_mod: 0 -- none, 1 -- neg
+ LLVMMatchType<1>, // %B
+ llvm_i16_ty, // %C_mod: 0 -- none, 1 -- neg, 2 -- abs, 3 -- neg(abs)
+ LLVMMatchType<0>, // %C
+ llvm_i1_ty, // matrix_a_reuse
+ llvm_i1_ty, // matrix_b_reuse
+ ],
+ [IntrNoMem, IntrConvergent, ImmArg<ArgIndex<0>>, ImmArg<ArgIndex<2>>, ImmArg<ArgIndex<4>>, ImmArg<ArgIndex<6>>, ImmArg<ArgIndex<7>>,
+ IntrWillReturn, IntrNoCallback, IntrNoFree]
+>;
+
+// D and C are of different types.
+class AMDGPUWmmaIntrinsicModsAllDiff<LLVMType DstTy, LLVMType AB, LLVMType C> :
+ Intrinsic<
+ [DstTy], // %D
+ [
+ llvm_i1_ty, // %A_mod: 0 -- none, 1 -- neg
+ AB, // %A
+ llvm_i1_ty, // %B_mod: 0 -- none, 1 -- neg
+ LLVMMatchType<1>, // %B
+ llvm_i16_ty, // %C_mod: 0 -- none, 1 -- neg, 2 -- abs, 3 -- neg(abs)
+ C, // %C
+ llvm_i1_ty, // matrix_a_reuse
+ llvm_i1_ty, // matrix_b_reuse
+ ],
+ [IntrNoMem, IntrConvergent, ImmArg<ArgIndex<0>>, ImmArg<ArgIndex<2>>, ImmArg<ArgIndex<4>>, ImmArg<ArgIndex<6>>, ImmArg<ArgIndex<7>>,
+ IntrWillReturn, IntrNoCallback, IntrNoFree]
+>;
+
+defset list<Intrinsic> AMDGPUWMMAIntrinsicsGFX1250 = {
+def int_amdgcn_wmma_f64_16x16x4_f64 : AMDGPUWmmaIntrinsicModsAll<llvm_anyfloat_ty, llvm_anyfloat_ty>;
+def int_amdgcn_wmma_f32_16x16x4_f32 : AMDGPUWmmaIntrinsicModsAllReuse<llvm_anyfloat_ty, llvm_anyfloat_ty>;
+def int_amdgcn_wmma_f32_16x16x32_bf16 : AMDGPUWmmaIntrinsicModsAllReuse<llvm_anyfloat_ty, llvm_anyfloat_ty>;
+def int_amdgcn_wmma_f32_16x16x32_f16 : AMDGPUWmmaIntrinsicModsAllReuse<llvm_anyfloat_ty, llvm_anyfloat_ty>;
+def int_amdgcn_wmma_f16_16x16x32_f16 : AMDGPUWmmaIntrinsicModsAllReuse<llvm_anyfloat_ty, llvm_anyfloat_ty>;
+def int_amdgcn_wmma_bf16_16x16x32_bf16 : AMDGPUWmmaIntrinsicModsAllReuse<llvm_anyfloat_ty, llvm_anyfloat_ty>;
+def int_amdgcn_wmma_bf16f32_16x16x32_bf16 : AMDGPUWmmaIntrinsicModsAllDiff<llvm_anyfloat_ty, llvm_anyfloat_ty, llvm_anyfloat_ty>;
+def int_amdgcn_wmma_f32_16x16x64_fp8_fp8 : AMDGPUWmmaIntrinsicModsC<llvm_anyint_ty, llvm_anyfloat_ty>;
+def int_amdgcn_wmma_f32_16x16x64_fp8_bf8 : AMDGPUWmmaIntrinsicModsC<llvm_anyint_ty, llvm_anyfloat_ty>;
+def int_amdgcn_wmma_f32_16x16x64_bf8_fp8 : AMDGPUWmmaIntrinsicModsC<llvm_anyint_ty, llvm_anyfloat_ty>;
+def int_amdgcn_wmma_f32_16x16x64_bf8_bf8 : AMDGPUWmmaIntrinsicModsC<llvm_anyint_ty, llvm_anyfloat_ty>;
+def int_amdgcn_wmma_f16_16x16x64_fp8_fp8 : AMDGPUWmmaIntrinsicModsC<llvm_anyint_ty, llvm_anyfloat_ty>;
+def int_amdgcn_wmma_f16_16x16x64_fp8_bf8 : AMDGPUWmmaIntrinsicModsC<llvm_anyint_ty, llvm_anyfloat_ty>;
+def int_amdgcn_wmma_f16_16x16x64_bf8_fp8 : AMDGPUWmmaIntrinsicModsC<llvm_anyint_ty, llvm_anyfloat_ty>;
+def int_amdgcn_wmma_f16_16x16x64_bf8_bf8 : AMDGPUWmmaIntrinsicModsC<llvm_anyint_ty, llvm_anyfloat_ty>;
+def int_amdgcn_wmma_f16_16x16x128_fp8_fp8 : AMDGPUWmmaIntrinsicModsC<llvm_anyint_ty, llvm_anyfloat_ty>;
+def int_amdgcn_wmma_f16_16x16x128_fp8_bf8 : AMDGPUWmmaIntrinsicModsC<llvm_anyint_ty, llvm_anyfloat_ty>;
+def int_amdgcn_wmma_f16_16x16x128_bf8_fp8 : AMDGPUWmmaIntrinsicModsC<llvm_anyint_ty, llvm_anyfloat_ty>;
+def int_amdgcn_wmma_f16_16x16x128_bf8_bf8 : AMDGPUWmmaIntrinsicModsC<llvm_anyint_ty, llvm_anyfloat_ty>;
+def int_amdgcn_wmma_f32_16x16x128_fp8_fp8 : AMDGPUWmmaIntrinsicModsC<llvm_anyint_ty, llvm_anyfloat_ty>;
+def int_amdgcn_wmma_f32_16x16x128_fp8_bf8 : AMDGPUWmmaIntrinsicModsC<llvm_anyint_ty, llvm_anyfloat_ty>;
+def int_amdgcn_wmma_f32_16x16x128_bf8_fp8 : AMDGPUWmmaIntrinsicModsC<llvm_anyint_ty, llvm_anyfloat_ty>;
+def int_amdgcn_wmma_f32_16x16x128_bf8_bf8 : AMDGPUWmmaIntrinsicModsC<llvm_anyint_ty, llvm_anyfloat_ty>;
+def int_amdgcn_wmma_i32_16x16x64_iu8 : AMDGPUWmmaIntrinsicModsAB<llvm_anyint_ty, llvm_anyint_ty>;
+def int_amdgcn_wmma_f32_32x16x128_f4 : AMDGPUWmmaIntrinsicF4ModsC<llvm_anyint_ty, llvm_anyint_ty, llvm_anyfloat_ty>;
+}
+
+class AMDGPUSWmmacIntrinsicABIdx<LLVMType A, LLVMType B, LLVMType CD, LLVMType Index> :
+ Intrinsic<
+ [CD], // %D
+ [
+ llvm_i1_ty, // %A_mod: 0 - none, 1 - neg
+ A, // %A
+ llvm_i1_ty, // %B_mod: 0 - none, 1 - neg
+ B, // %B
+ LLVMMatchType<0>, // %C
+ Index, // %Sparsity index for A
+ llvm_i1_ty, // matrix_a_reuse
+ llvm_i1_ty, // matrix_b_reuse
+ ],
+ [IntrNoMem, IntrConvergent, IntrWillReturn, ImmArg<ArgIndex<0>>, ImmArg<ArgIndex<2>>, ImmArg<ArgIndex<6>>, ImmArg<ArgIndex<7>>]
+>;
+
+defset list<Intrinsic> AMDGPUSWMMACIntrinsicsGFX1250 = {
+def int_amdgcn_swmmac_f32_16x16x64_f16 : AMDGPUSWmmacIntrinsicABIdx<llvm_anyfloat_ty, llvm_anyfloat_ty, llvm_anyfloat_ty, llvm_anyint_ty>;
+def int_amdgcn_swmmac_f32_16x16x64_bf16 : AMDGPUSWmmacIntrinsicABIdx<llvm_anyfloat_ty, llvm_anyfloat_ty, llvm_anyfloat_ty, llvm_anyint_ty>;
+def int_amdgcn_swmmac_f16_16x16x64_f16 : AMDGPUSWmmacIntrinsicABIdx<llvm_anyfloat_ty, llvm_anyfloat_ty, llvm_anyfloat_ty, llvm_anyint_ty>;
+def int_amdgcn_swmmac_bf16_16x16x64_bf16 : AMDGPUSWmmacIntrinsicABIdx<llvm_anyfloat_ty, llvm_anyfloat_ty, llvm_anyfloat_ty, llvm_anyint_ty>;
+def int_amdgcn_swmmac_bf16f32_16x16x64_bf16 : AMDGPUSWmmacIntrinsicABIdx<llvm_anyfloat_ty, llvm_anyfloat_ty, llvm_anyfloat_ty, llvm_anyint_ty>;
+def int_amdgcn_swmmac_f32_16x16x128_fp8_fp8 : AMDGPUSWmmacIntrinsicIdxReuse<llvm_anyint_ty, llvm_anyint_ty, llvm_anyfloat_ty, llvm_anyint_ty>;
+def int_amdgcn_swmmac_f32_16x16x128_fp8_bf8 : AMDGPUSWmmacIntrinsicIdxReuse<llvm_anyint_ty, llvm_anyint_ty, llvm_anyfloat_ty, llvm_anyint_ty>;
+def int_amdgcn_swmmac_f32_16x16x128_bf8_fp8 : AMDGPUSWmmacIntrinsicIdxReuse<llvm_anyint_ty, llvm_anyint_ty, llvm_anyfloat_ty, llvm_anyint_ty>;
+def int_amdgcn_swmmac_f32_16x16x128_bf8_bf8 : AMDGPUSWmmacIntrinsicIdxReuse<llvm_anyint_ty, llvm_anyint_ty, llvm_anyfloat_ty, llvm_anyint_ty>;
+def int_amdgcn_swmmac_f16_16x16x128_fp8_fp8 : AMDGPUSWmmacIntrinsicIdxReuse<llvm_anyint_ty, llvm_anyint_ty, llvm_anyfloat_ty, llvm_anyint_ty>;
+def int_amdgcn_swmmac_f16_16x16x128_fp8_bf8 : AMDGPUSWmmacIntrinsicIdxReuse<llvm_anyint_ty, llvm_anyint_ty, llvm_anyfloat_ty, llvm_anyint_ty>;
+def int_amdgcn_swmmac_f16_16x16x128_bf8_fp8 : AMDGPUSWmmacIntrinsicIdxReuse<llvm_anyint_ty, llvm_anyint_ty, llvm_anyfloat_ty, llvm_anyint_ty>;
+def int_amdgcn_swmmac_f16_16x16x128_bf8_bf8 : AMDGPUSWmmacIntrinsicIdxReuse<llvm_anyint_ty, llvm_anyint_ty, llvm_anyfloat_ty, llvm_anyint_ty>;
+def int_amdgcn_swmmac_i32_16x16x128_iu8 : AMDGPUSWmmacIntrinsicABIdx<llvm_anyint_ty, llvm_anyint_ty, llvm_anyint_ty, llvm_anyint_ty>;
+}
+
+
class AMDGPUTensorLoadStore:
Intrinsic<
[],
diff --git a/llvm/lib/Target/AMDGPU/AMDGPU.td b/llvm/lib/Target/AMDGPU/AMDGPU.td
index 04ec1716478e4..712dcadf2d317 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPU.td
+++ b/llvm/lib/Target/AMDGPU/AMDGPU.td
@@ -844,6 +844,12 @@ def FeatureCvtFP8VOP1Bug : SubtargetFeature<"cvt-fp8-vop1-bug",
[FeatureFP8ConversionInsts]
>;
+def FeatureWMMA128bInsts : SubtargetFeature<"wmma-128b-insts",
+ "HasWMMA128bInsts",
+ "true",
+ "Has WMMA instructions where A and B matrices do not have duplicated data"
+>;
+
def FeaturePkFmacF16Inst : SubtargetFeature<"pk-fmac-f16-inst",
"HasPkFmacF16Inst",
"true",
@@ -1925,6 +1931,7 @@ def FeatureISAVersion12 : FeatureSet<
FeatureImageInsts,
FeatureExtendedImageInsts,
FeatureFP8ConversionInsts,
+ FeatureWMMA128bInsts,
FeatureIEEEMinimumMaximumInsts,
FeaturePackedTID,
FeatureVcmpxPermlaneHazard,
@@ -2291,6 +2298,10 @@ def isGFX11Plus :
Predicate<"Subtarget->getGeneration() >= AMDGPUSubtarget::GFX11">,
AssemblerPredicate<(all_of FeatureGFX11Insts)>;
+def isGFX11PlusNot12_50 :
+ Predicate<"Subtarget->getGeneration() >= AMDGPUSubtarget::GFX11 && !Subtarget->hasGFX1250Insts()">,
+ AssemblerPredicate<(all_of FeatureGFX11Insts, (not FeatureGFX1250Insts))>;
+
def isGFX12Only :
Predicate<"Subtarget->getGeneration() == AMDGPUSubtarget::GFX12">,
AssemblerPredicate<(all_of FeatureGFX12Insts)>;
@@ -2616,6 +2627,9 @@ def HasFP8Insts : Predicate<"Subtarget->hasFP8Insts()">,
def HasFP8ConversionInsts : Predicate<"Subtarget->hasFP8ConversionInsts()">,
AssemblerPredicate<(all_of FeatureFP8ConversionInsts)>;
+def HasWMMA128bInsts : Predicate<"Subtarget->hasWMMA128bInsts()">,
+ AssemblerPredicate<(all_of FeatureWMMA128bInsts)>;
+
def HasFP8E5M3Insts : Predicate<"Subtarget->hasFP8E5M3Insts()">,
AssemblerPredicate<(all_of FeatureFP8E5M3Insts)>;
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUGISel.td b/llvm/lib/Target/AMDGPU/AMDGPUGISel.td
index 1b909568fc555..7b5d4077e85f3 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUGISel.td
+++ b/llvm/lib/Target/AMDGPU/AMDGPUGISel.td
@@ -55,6 +55,14 @@ def gi_vop3pmodsneg :
GIComplexOperandMatcher<s32, "selectVOP3PModsNeg">,
GIComplexPatternEquiv<VOP3PModsNeg>;
+def gi_vop3pmodsnegs :
+ GIComplexOperandMatcher<s32, "selectVOP3PModsNegs">,
+ GIComplexPatternEquiv<VOP3PModsNegs>;
+
+def gi_dotiuvop3pmodsnegabs :
+ GIComplexOperandMatcher<s32, "selectVOP3PModsNegAbs">,
+ GIComplexPatternEquiv<VOP3PModsNegAbs>;
+
def gi_wmmaopselvop3pmods :
GIComplexOperandMatcher<s32, "selectWMMAOpSelVOP3PMods">,
GIComplexPatternEquiv<WMMAOpSelVOP3PMods>;
@@ -83,6 +91,10 @@ def gi_swmmacindex16 :
GIComplexOperandMatcher<s32, "selectSWMMACIndex16">,
GIComplexPatternEquiv<SWMMACIndex16>;
+def gi_swmmacindex32 :
+ GIComplexOperandMatcher<s64, "selectSWMMACIndex32">,
+ GIComplexPatternEquiv<SWMMACIndex32>;
+
def gi_vop3opselmods :
GIComplexOperandMatcher<s32, "selectVOP3OpSelMods">,
GIComplexPatternEquiv<VOP3OpSelMods>;
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp
index 620eac428c084..25672a52345cb 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp
@@ -3273,6 +3273,7 @@ bool AMDGPUDAGToDAGISel::SelectVOP3PModsDOT(SDValue In, SDValue &Src,
return SelectVOP3PMods(In, Src, SrcMods, true);
}
+// Select neg_lo from the i1 immediate operand.
bool AMDGPUDAGToDAGISel::SelectVOP3PModsNeg(SDValue In, SDValue &Src) const {
const ConstantSDNode *C = cast<ConstantSDNode>(In);
// Literal i1 value set in intrinsic, represents SrcMods for the next operand.
@@ -3288,6 +3289,47 @@ bool AMDGPUDAGToDAGISel::SelectVOP3PModsNeg(SDValue In, SDValue &Src) const {
return true;
}
+// Select both neg_lo and neg_hi from the i1 immediate operand. This is
+// specifically for F16/BF16 operands in WMMA instructions, where neg_lo applies
+// to matrix's even k elements, and neg_hi applies to matrix's odd k elements.
+bool AMDGPUDAGToDAGISel::SelectVOP3PModsNegs(SDValue In, SDValue &Src) const {
+ const ConstantSDNode *C = cast<ConstantSDNode>(In);
+ // Literal i1 value set in intrinsic, represents SrcMods for the next operand.
+ // 1 promotes packed values to signed, 0 treats them as unsigned.
+ assert(C->getAPIntValue().getBitWidth() == 1 && "expected i1 value");
+
+ unsigned Mods = SISrcMods::OP_SEL_1;
+ unsigned SrcSign = C->getZExtValue();
+ if (SrcSign == 1)
+ Mods ^= (SISrcMods::NEG | SISrcMods::NEG_HI);
+
+ Src = CurDAG->getTargetConstant(Mods, SDLoc(In), MVT::i32);
+ return true;
+}
+
+// Select neg, abs, or both neg and abs from the i16 immediate operans.
+bool AMDGPUDAGToDAGISel::SelectVOP3PModsNegAbs(SDValue In, SDValue &Src) const {
+ const ConstantSDNode *C = cast<ConstantSDNode>(In);
+ unsigned Mods = SISrcMods::OP_SEL_1;
+ unsigned SrcMod = C->getZExtValue();
+ switch (SrcMod) {
+ default: // Any other value will be silently ignored (considered as 0).
+ break;
+ case 1:
+ Mods ^= SISrcMods::NEG;
+ break;
+ case 2:
+ Mods ^= SISrcMods::ABS;
+ break;
+ case 3:
+ Mods ^= (SISrcMods::NEG | SISrcMods::ABS);
+ break;
+ }
+
+ Src = CurDAG->getTargetConstant(Mods, SDLoc(In), MVT::i32);
+ return true;
+}
+
bool AMDGPUDAGToDAGISel::SelectWMMAOpSelVOP3PMods(SDValue In,
SDValue &Src) const {
const ConstantSDNode *C = cast<ConstantSDNode>(In);
@@ -3639,6 +3681,41 @@ bool AMDGPUDAGToDAGISel::SelectSWMMACIndex16(SDValue In, SDValue &Src,
return true;
}
+bool AMDGPUDAGToDAGISel::SelectSWMMACIndex32(SDValue In, SDValue &Src,
+ SDValue &IndexKey) const {
+ unsigned Key = 0;
+ Src = In;
+
+ SDValue InI32;
+
+ if (In.getOpcode() == ISD::ANY_EXTEND || In.getOpcode() == ISD::ZERO_EXTEND) {
+ const SDValue &ExtendSrc = In.getOperand(0);
+ if (ExtendSrc.getValueSizeInBits() == 32)
+ InI32 = ExtendSrc;
+ } else if (In->getOpcode() == ISD::BITCAST) {
+ const SDValue &CastSrc = In.getOperand(0);
+ if (CastSrc.getOpcode() == ISD::BUILD_VECTOR &&
+ CastSrc.getOperand(0).getValueSizeInBits() == 32) {
+ ConstantSDNode *Zero = dyn_cast<ConstantSDNode>(CastSrc.getOperand(1));
+ if (Zero && Zero->getZExtValue() == 0)
+ InI32 = CastSrc.getOperand(0);
+ }
+ }
+
+ if (InI32 && InI32.getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
+ const SDValue &ExtractVecEltSrc = InI32.getOperand(0);
+ ConstantSDNode *EltIdx = dyn_cast<ConstantSDNode>(InI32.getOperand(1));
+ if (ExtractVecEltSrc.getValueSizeInBits() == 64 && EltIdx &&
+ EltIdx->getZExtValue() == 1) {
+ Key = 1;
+ Src = ExtractVecEltSrc;
+ }
+ }
+
+ IndexKey = CurDAG->getTargetConstant(Key, SDLoc(In), MVT::i32);
+ return true;
+}
+
bool AMDGPUDAGToDAGISel::SelectVOP3OpSel(SDValue In, SDValue &Src,
SDValue &SrcMods) const {
Src = In;
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.h b/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.h
index f3b9364fdb92b..9967f46e085e4 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.h
@@ -222,6 +222,8 @@ class AMDGPUDAGToDAGISel : public SelectionDAGISel {
bool SelectVOP3PModsDOT(SDValue In, SDValue &Src, SDValue &SrcMods) const;
bool SelectVOP3PModsNeg(SDValue In, SDValue &Src) const;
+ bool SelectVOP3PModsNegs(SDValue In, SDValue &Src) const;
+ bool SelectVOP3PModsNegAbs(SDValue In, SDValue &Src) const;
bool SelectWMMAOpSelVOP3PMods(SDValue In, SDValue &Src) const;
bool SelectWMMAModsF32NegAbs(SDValue In, SDValue &Src,
@@ -233,6 +235,7 @@ class AMDGPUDAGToDAGISel : public SelectionDAGISel {
bool SelectSWMMACIndex8(SDValue In, SDValue &Src, SDValue &IndexKey) const;
bool SelectSWMMACIndex16(SDValue In, SDValue &Src, SDValue &IndexKey) const;
+ bool SelectSWMMACIndex32(SDValue In, SDValue &Src, SDValue &IndexKey) const;
bool SelectVOP3OpSel(SDValue In, SDValue &Src, SDValue &SrcMods) const;
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp b/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp
index ea79c57080faa..1a63c48e3666c 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp
@@ -3513,6 +3513,25 @@ static Register matchZeroExtendFromS32(MachineRegisterInfo &MRI, Register Reg) {
return Register();
}
+Register AMDGPUInstructionSelector::matchAnyExtendFromS32(Register Reg) const {
+ Register AnyExtSrc;
+ if (mi_match(Reg, *MRI, m_GAnyExt(m_Reg(AnyExtSrc))))
+ return MRI->getType(AnyExtSrc) == LLT::scalar(32) ? AnyExtSrc : Register();
+
+ // Match legalized form %zext = G_MERGE_VALUES (s32 %x), (s32 G_IMPLICIT_DEF)
+ const MachineInstr *Def = getDefIgnoringCopies(Reg, *MRI);
+ if (Def->getOpcode() != AMDGPU::G_MERGE_VALUES)
+ return Register();
+
+ assert(Def->getNumOperands() == 3 &&
+ MRI->getType(Def->getOperand(0).getReg()) == LLT::scalar(64));
+
+ if (mi_match(Def->getOperand(2).getReg(), *MRI, m_GImplicitDef()))
+ return Def->getOperand(1).getReg();
+
+ return Register();
+}
+
bool AMDGPUInstructionSelector::selectGlobalLoadLds(MachineInstr &MI) const{
if (!Subtarget->hasVMemToLDSLoad())
return false;
@@ -4904,6 +4923,7 @@ AMDGPUInstructionSelector::selectVOP3PModsDOT(MachineOperand &Root) const {
return selectVOP3PRetHelper(Root, true);
}
+// Select neg_lo from the i1 immediate operand.
InstructionSelector::ComplexRendererFns
AMDGPUInstructionSelector::selectVOP3PModsNeg(MachineOperand &Root) const {
// Literal i1 value set in intrinsic, represents SrcMods for the next operand.
@@ -4919,6 +4939,50 @@ AMDGPUInstructionSelector::selectVOP3PModsNeg(MachineOperand &Root) const {
}};
}
+// Select both neg_lo and neg_hi from the i1 immediate operand. This is
+// specifically for F16/BF16 operands in WMMA instructions, where neg_lo applies
+// to matrix's even k elements, and neg_hi applies to matrix's odd k elements.
+InstructionSelector::ComplexRendererFns
+AMDGPUInstructionSelector::selectVOP3PModsNegs(MachineOperand &Root) const {
+ // Literal i1 va...
[truncated]
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
Fix a format error and remove a useless Predicate
Remove unnecessary Feature and Predicates
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/3/builds/19082 Here is the relevant piece of the build log for the reference
|
No description provided.