[NVPTX] Allow directly storing immediates to improve readability #145552

Merged
merged 3 commits into from
Jun 26, 2025

Conversation

AlexMaclean
Member

Allow directly storing an immediate instead of requiring that it first be moved into a register. This makes for more compact and readable PTX. A similar approach (using a ComplexPattern) could be used for most PTX instructions to avoid the need for _[ri]+ variants and boilerplate.
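
To make the effect on the emitted PTX concrete, here is a minimal standalone C++ sketch. It is not LLVM code: `StoreValue` and `emitStoreB32` are hypothetical names invented for illustration, and the strings only mirror the `access-non-generic.ll` expectation changed by this patch.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical model of the value being stored: a virtual register or a literal.
struct StoreValue {
  bool isImmediate;
  std::string text; // "%r1" for a register, "1" for an immediate
};

// Returns the PTX-like lines for one 32-bit store.
// foldImmediates=false models the old output (mov into scratchReg, then st);
// foldImmediates=true models the new output (immediate used directly).
std::vector<std::string> emitStoreB32(const std::string &addrSpace,
                                      const std::string &addr,
                                      const StoreValue &value,
                                      bool foldImmediates,
                                      const std::string &scratchReg) {
  std::vector<std::string> lines;
  std::string src = value.text;
  if (value.isImmediate && !foldImmediates) {
    // Old behaviour: materialize the constant in a register first.
    lines.push_back("mov.b32 " + scratchReg + ", " + value.text + ";");
    src = scratchReg;
  }
  lines.push_back("st" + addrSpace + ".b32 [" + addr + "], " + src + ";");
  return lines;
}
```

Calling `emitStoreB32(".shared", "array+4", {true, "1"}, false, "%r1")` yields the two-line `mov.b32` / `st.shared.b32` pair, while passing `true` for `foldImmediates` yields the single line `st.shared.b32 [array+4], 1;`.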

@llvmbot
Member

llvmbot commented Jun 24, 2025

@llvm/pr-subscribers-backend-nvptx

Author: Alex MacLean (AlexMaclean)

Changes

Allow directly storing an immediate instead of requiring that it first be moved into a register. This makes for more compact and readable PTX. A similar approach (using a ComplexPattern) could be used for most PTX instructions to avoid the need for _[ri]+ variants and boilerplate.


Patch is 30.63 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/145552.diff

16 Files Affected:

  • (modified) llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp (+22-10)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h (+1)
  • (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (+38-25)
  • (modified) llvm/test/CodeGen/NVPTX/access-non-generic.ll (+1-2)
  • (modified) llvm/test/CodeGen/NVPTX/chain-different-as.ll (+4-5)
  • (modified) llvm/test/CodeGen/NVPTX/demote-vars.ll (+1-2)
  • (modified) llvm/test/CodeGen/NVPTX/i1-load-lower.ll (+2-3)
  • (modified) llvm/test/CodeGen/NVPTX/i128-ld-st.ll (+1-1)
  • (modified) llvm/test/CodeGen/NVPTX/jump-table.ll (+5-9)
  • (modified) llvm/test/CodeGen/NVPTX/local-stack-frame.ll (+5-8)
  • (modified) llvm/test/CodeGen/NVPTX/lower-alloca.ll (+2-2)
  • (modified) llvm/test/CodeGen/NVPTX/lower-byval-args.ll (+4-6)
  • (modified) llvm/test/CodeGen/NVPTX/param-align.ll (+8-14)
  • (modified) llvm/test/CodeGen/NVPTX/pr13291-i1-store.ll (+2-4)
  • (modified) llvm/test/CodeGen/NVPTX/reg-types.ll (+36-53)
  • (modified) llvm/test/CodeGen/NVPTX/variadics-backend.ll (+33-42)
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index ff10eea371049..8f5a3a4f72234 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -1364,20 +1364,18 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
   SDValue Offset, Base;
   SelectADDR(ST->getBasePtr(), Base, Offset);
 
-  SDValue Ops[] = {Value,
+  SDValue Ops[] = {selectPossiblyImm(Value),
                    getI32Imm(Ordering, DL),
                    getI32Imm(Scope, DL),
                    getI32Imm(CodeAddrSpace, DL),
-                   getI32Imm(NVPTX::PTXLdStInstCode::Untyped, DL),
                    getI32Imm(ToTypeWidth, DL),
                    Base,
                    Offset,
                    Chain};
 
-  const MVT::SimpleValueType SourceVT =
-      Value.getNode()->getSimpleValueType(0).SimpleTy;
-  const std::optional<unsigned> Opcode = pickOpcodeForVT(
-      SourceVT, NVPTX::ST_i8, NVPTX::ST_i16, NVPTX::ST_i32, NVPTX::ST_i64);
+  const std::optional<unsigned> Opcode =
+      pickOpcodeForVT(Value.getSimpleValueType().SimpleTy, NVPTX::ST_i8,
+                      NVPTX::ST_i16, NVPTX::ST_i32, NVPTX::ST_i64);
   if (!Opcode)
     return false;
 
@@ -1414,7 +1412,9 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
 
   const unsigned NumElts = getLoadStoreVectorNumElts(ST);
 
-  SmallVector<SDValue, 16> Ops(ST->ops().slice(1, NumElts));
+  SmallVector<SDValue, 16> Ops;
+  for (auto &V : ST->ops().slice(1, NumElts))
+    Ops.push_back(selectPossiblyImm(V));
   SDValue Addr = N->getOperand(NumElts + 1);
   const unsigned ToTypeWidth = TotalWidth / NumElts;
 
@@ -1425,9 +1425,8 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
   SelectADDR(Addr, Base, Offset);
 
   Ops.append({getI32Imm(Ordering, DL), getI32Imm(Scope, DL),
-              getI32Imm(CodeAddrSpace, DL),
-              getI32Imm(NVPTX::PTXLdStInstCode::Untyped, DL),
-              getI32Imm(ToTypeWidth, DL), Base, Offset, Chain});
+              getI32Imm(CodeAddrSpace, DL), getI32Imm(ToTypeWidth, DL), Base,
+              Offset, Chain});
 
   const MVT::SimpleValueType EltVT =
       ST->getOperand(1).getSimpleValueType().SimpleTy;
@@ -2158,6 +2157,19 @@ bool NVPTXDAGToDAGISel::SelectADDR(SDValue Addr, SDValue &Base,
   return true;
 }
 
+SDValue NVPTXDAGToDAGISel::selectPossiblyImm(SDValue V) {
+  if (V.getOpcode() == ISD::BITCAST)
+    V = V.getOperand(0);
+
+  if (auto *CN = dyn_cast<ConstantSDNode>(V))
+    return CurDAG->getTargetConstant(CN->getAPIntValue(), SDLoc(V),
+                                     V.getValueType());
+  if (auto *CN = dyn_cast<ConstantFPSDNode>(V))
+    return CurDAG->getTargetConstantFP(CN->getValueAPF(), SDLoc(V),
+                                       V.getValueType());
+  return V;
+}
+
 bool NVPTXDAGToDAGISel::ChkMemSDNodeAddressSpace(SDNode *N,
                                                  unsigned int spN) const {
   const Value *Src = nullptr;
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
index ff58e4486a222..0fbcd60a3a3eb 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
@@ -106,6 +106,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
   }
 
   bool SelectADDR(SDValue Addr, SDValue &Base, SDValue &Offset);
+  SDValue selectPossiblyImm(SDValue V);
 
   bool ChkMemSDNodeAddressSpace(SDNode *N, unsigned int spN) const;
 
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 5979054764647..332a04451065c 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -184,6 +184,18 @@ class OneUse2<SDPatternOperator operator>
 class fpimm_pos_inf<ValueType vt>
     : FPImmLeaf<vt, [{ return Imm.isPosInfinity(); }]>;
 
+
+
+// Operands which can hold a Register or an Immediate.
+//
+// Unfortunately, since most register classes can hold multiple types, we must
+// use the 'Any' type for these.
+
+def RI1  : Operand<i1>;
+def RI16 : Operand<Any>;
+def RI32 : Operand<Any>;
+def RI64 : Operand<Any>;
+
 // Utility class to wrap up information about a register and DAG type for more
 // convenient iteration and parameterization
 class RegTyInfo<ValueType ty, NVPTXRegClass rc, Operand imm, SDNode imm_node,
@@ -2338,19 +2350,20 @@ let mayLoad=1, hasSideEffects=0 in {
   def LD_i64 : LD<B64>;
 }
 
-class ST<NVPTXRegClass regclass>
+class ST<DAGOperand O>
   : NVPTXInst<
     (outs),
-    (ins regclass:$src, LdStCode:$sem, LdStCode:$scope, LdStCode:$addsp,
-         LdStCode:$Sign, i32imm:$toWidth, ADDR:$addr),
-    "st${sem:sem}${scope:scope}${addsp:addsp}.${Sign:sign}$toWidth"
+    (ins O:$src,
+         LdStCode:$sem, LdStCode:$scope, LdStCode:$addsp, i32imm:$toWidth,
+         ADDR:$addr),
+    "st${sem:sem}${scope:scope}${addsp:addsp}.b$toWidth"
     " \t[$addr], $src;", []>;
 
 let mayStore=1, hasSideEffects=0 in {
-  def ST_i8  : ST<B16>;
-  def ST_i16 : ST<B16>;
-  def ST_i32 : ST<B32>;
-  def ST_i64 : ST<B64>;
+  def ST_i8  : ST<RI16>;
+  def ST_i16 : ST<RI16>;
+  def ST_i32 : ST<RI32>;
+  def ST_i64 : ST<RI64>;
 }
 
 // The following is used only in and after vector elementizations.  Vector
@@ -2386,38 +2399,38 @@ let mayLoad=1, hasSideEffects=0 in {
   defm LDV_i64 : LD_VEC<B64>;
 }
 
-multiclass ST_VEC<NVPTXRegClass regclass, bit support_v8 = false> {
+multiclass ST_VEC<DAGOperand O, bit support_v8 = false> {
   def _v2 : NVPTXInst<
     (outs),
-    (ins regclass:$src1, regclass:$src2, LdStCode:$sem, LdStCode:$scope,
-         LdStCode:$addsp, LdStCode:$Sign, i32imm:$fromWidth,
+    (ins O:$src1, O:$src2,
+         LdStCode:$sem, LdStCode:$scope, LdStCode:$addsp, i32imm:$fromWidth,
          ADDR:$addr),
-    "st${sem:sem}${scope:scope}${addsp:addsp}.v2.${Sign:sign}$fromWidth "
+    "st${sem:sem}${scope:scope}${addsp:addsp}.v2.b$fromWidth "
     "\t[$addr], {{$src1, $src2}};", []>;
   def _v4 : NVPTXInst<
     (outs),
-    (ins regclass:$src1, regclass:$src2, regclass:$src3, regclass:$src4,
-         LdStCode:$sem, LdStCode:$scope, LdStCode:$addsp,
-         LdStCode:$Sign, i32imm:$fromWidth, ADDR:$addr),
-    "st${sem:sem}${scope:scope}${addsp:addsp}.v4.${Sign:sign}$fromWidth "
+    (ins O:$src1, O:$src2, O:$src3, O:$src4,
+         LdStCode:$sem, LdStCode:$scope, LdStCode:$addsp, i32imm:$fromWidth,
+         ADDR:$addr),
+    "st${sem:sem}${scope:scope}${addsp:addsp}.v4.b$fromWidth "
     "\t[$addr], {{$src1, $src2, $src3, $src4}};", []>;
   if support_v8 then
     def _v8 : NVPTXInst<
       (outs),
-      (ins regclass:$src1, regclass:$src2, regclass:$src3, regclass:$src4,
-           regclass:$src5, regclass:$src6, regclass:$src7, regclass:$src8,
-           LdStCode:$sem, LdStCode:$scope, LdStCode:$addsp, LdStCode:$Sign,
-           i32imm:$fromWidth, ADDR:$addr),
-      "st${sem:sem}${scope:scope}${addsp:addsp}.v8.${Sign:sign}$fromWidth "
+      (ins O:$src1, O:$src2, O:$src3, O:$src4,
+           O:$src5, O:$src6, O:$src7, O:$src8,
+           LdStCode:$sem, LdStCode:$scope, LdStCode:$addsp, i32imm:$fromWidth, 
+           ADDR:$addr),
+      "st${sem:sem}${scope:scope}${addsp:addsp}.v8.b$fromWidth "
       "\t[$addr], "
       "{{$src1, $src2, $src3, $src4, $src5, $src6, $src7, $src8}};", []>;
 }
 
 let mayStore=1, hasSideEffects=0 in {
-  defm STV_i8  : ST_VEC<B16>;
-  defm STV_i16 : ST_VEC<B16>;
-  defm STV_i32 : ST_VEC<B32, support_v8 = true>;
-  defm STV_i64 : ST_VEC<B64>;
+  defm STV_i8  : ST_VEC<RI16>;
+  defm STV_i16 : ST_VEC<RI16>;
+  defm STV_i32 : ST_VEC<RI32, support_v8 = true>;
+  defm STV_i64 : ST_VEC<RI64>;
 }
 
 //---- Conversion ----
diff --git a/llvm/test/CodeGen/NVPTX/access-non-generic.ll b/llvm/test/CodeGen/NVPTX/access-non-generic.ll
index 9edd4de017ee2..601a35288f54d 100644
--- a/llvm/test/CodeGen/NVPTX/access-non-generic.ll
+++ b/llvm/test/CodeGen/NVPTX/access-non-generic.ll
@@ -107,8 +107,7 @@ define void @nested_const_expr() {
 ; PTX-LABEL: nested_const_expr(
   ; store 1 to bitcast(gep(addrspacecast(array), 0, 1))
   store i32 1, ptr getelementptr ([10 x float], ptr addrspacecast (ptr addrspace(3) @array to ptr), i64 0, i64 1), align 4
-; PTX: mov.b32 %r1, 1;
-; PTX-NEXT: st.shared.b32 [array+4], %r1;
+; PTX: st.shared.b32 [array+4], 1;
   ret void
 }
 
diff --git a/llvm/test/CodeGen/NVPTX/chain-different-as.ll b/llvm/test/CodeGen/NVPTX/chain-different-as.ll
index f2d0d9d069ea6..a33d286b47381 100644
--- a/llvm/test/CodeGen/NVPTX/chain-different-as.ll
+++ b/llvm/test/CodeGen/NVPTX/chain-different-as.ll
@@ -4,14 +4,13 @@
 define i64 @test() nounwind readnone {
 ; CHECK-LABEL: test(
 ; CHECK:       {
-; CHECK-NEXT:    .reg .b64 %rd<4>;
+; CHECK-NEXT:    .reg .b64 %rd<3>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0:
 ; CHECK-NEXT:    mov.b64 %rd1, 1;
-; CHECK-NEXT:    mov.b64 %rd2, 42;
-; CHECK-NEXT:    st.b64 [%rd1], %rd2;
-; CHECK-NEXT:    ld.global.b64 %rd3, [%rd1];
-; CHECK-NEXT:    st.param.b64 [func_retval0], %rd3;
+; CHECK-NEXT:    st.b64 [%rd1], 42;
+; CHECK-NEXT:    ld.global.b64 %rd2, [%rd1];
+; CHECK-NEXT:    st.param.b64 [func_retval0], %rd2;
 ; CHECK-NEXT:    ret;
   %addr0 = inttoptr i64 1 to ptr
   %addr1 = inttoptr i64 1 to ptr addrspace(1)
diff --git a/llvm/test/CodeGen/NVPTX/demote-vars.ll b/llvm/test/CodeGen/NVPTX/demote-vars.ll
index ab89b62b53d05..e554e4aaea36f 100644
--- a/llvm/test/CodeGen/NVPTX/demote-vars.ll
+++ b/llvm/test/CodeGen/NVPTX/demote-vars.ll
@@ -67,8 +67,7 @@ define void @define_private_global(i64 %val) {
 ; Also check that the if-then is still here, otherwise we may not be testing
 ; the "more-than-one-use" part.
 ; CHECK: st.shared.b64   [private_global_used_more_than_once_in_same_fct],
-; CHECK: mov.b64 %[[VAR:.*]], 25
-; CHECK: st.shared.b64   [private_global_used_more_than_once_in_same_fct], %[[VAR]]
+; CHECK: st.shared.b64   [private_global_used_more_than_once_in_same_fct], 25
 define void @define_private_global_more_than_one_use(i64 %val, i1 %cond) {
   store i64 %val, ptr addrspace(3) @private_global_used_more_than_once_in_same_fct
   br i1 %cond, label %then, label %end
diff --git a/llvm/test/CodeGen/NVPTX/i1-load-lower.ll b/llvm/test/CodeGen/NVPTX/i1-load-lower.ll
index 50d39c88a46b9..5214d272e161f 100644
--- a/llvm/test/CodeGen/NVPTX/i1-load-lower.ll
+++ b/llvm/test/CodeGen/NVPTX/i1-load-lower.ll
@@ -10,14 +10,13 @@ target triple = "nvptx-nvidia-cuda"
 define void @foo() {
 ; CHECK-LABEL: foo(
 ; CHECK:    .reg .pred %p<2>;
-; CHECK:    .reg .b16 %rs<4>;
+; CHECK:    .reg .b16 %rs<3>;
 ; CHECK-EMPTY:
 ; CHECK:    ld.global.b8 %rs1, [i1g];
 ; CHECK:    and.b16 %rs2, %rs1, 1;
 ; CHECK:    setp.ne.b16 %p1, %rs2, 0;
 ; CHECK:    @%p1 bra $L__BB0_2;
-; CHECK:    mov.b16 %rs3, 1;
-; CHECK:    st.global.b8 [i1g], %rs3;
+; CHECK:    st.global.b8 [i1g], 1;
 ; CHECK:    ret;
   %tmp = load i1, ptr addrspace(1) @i1g, align 2
   br i1 %tmp, label %if.end, label %if.then
diff --git a/llvm/test/CodeGen/NVPTX/i128-ld-st.ll b/llvm/test/CodeGen/NVPTX/i128-ld-st.ll
index 6bf65d4d4ad69..abe92a5bf79b9 100644
--- a/llvm/test/CodeGen/NVPTX/i128-ld-st.ll
+++ b/llvm/test/CodeGen/NVPTX/i128-ld-st.ll
@@ -13,8 +13,8 @@ define i128 @foo(ptr %p, ptr %o) {
 ; CHECK-NEXT:    ld.param.b64 %rd2, [foo_param_1];
 ; CHECK-NEXT:    ld.param.b64 %rd1, [foo_param_0];
 ; CHECK-NEXT:    ld.b8 %rd3, [%rd1];
+; CHECK-NEXT:    st.v2.b64 [%rd2], {%rd3, 0};
 ; CHECK-NEXT:    mov.b64 %rd4, 0;
-; CHECK-NEXT:    st.v2.b64 [%rd2], {%rd3, %rd4};
 ; CHECK-NEXT:    st.param.v2.b64 [func_retval0], {%rd3, %rd4};
 ; CHECK-NEXT:    ret;
   %c = load i8, ptr %p, align 1
diff --git a/llvm/test/CodeGen/NVPTX/jump-table.ll b/llvm/test/CodeGen/NVPTX/jump-table.ll
index 0718e6d603b6c..e1eeb66b5afc0 100644
--- a/llvm/test/CodeGen/NVPTX/jump-table.ll
+++ b/llvm/test/CodeGen/NVPTX/jump-table.ll
@@ -10,7 +10,7 @@ define void @foo(i32 %i) {
 ; CHECK-LABEL: foo(
 ; CHECK:       {
 ; CHECK-NEXT:    .reg .pred %p<2>;
-; CHECK-NEXT:    .reg .b32 %r<7>;
+; CHECK-NEXT:    .reg .b32 %r<3>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0: // %entry
 ; CHECK-NEXT:    ld.param.b32 %r2, [foo_param_0];
@@ -24,20 +24,16 @@ define void @foo(i32 %i) {
 ; CHECK-NEXT:     $L__BB0_5;
 ; CHECK-NEXT:    brx.idx %r2, $L_brx_0;
 ; CHECK-NEXT:  $L__BB0_2: // %case0
-; CHECK-NEXT:    mov.b32 %r6, 0;
-; CHECK-NEXT:    st.global.b32 [out], %r6;
+; CHECK-NEXT:    st.global.b32 [out], 0;
 ; CHECK-NEXT:    bra.uni $L__BB0_6;
 ; CHECK-NEXT:  $L__BB0_4: // %case2
-; CHECK-NEXT:    mov.b32 %r4, 2;
-; CHECK-NEXT:    st.global.b32 [out], %r4;
+; CHECK-NEXT:    st.global.b32 [out], 2;
 ; CHECK-NEXT:    bra.uni $L__BB0_6;
 ; CHECK-NEXT:  $L__BB0_5: // %case3
-; CHECK-NEXT:    mov.b32 %r3, 3;
-; CHECK-NEXT:    st.global.b32 [out], %r3;
+; CHECK-NEXT:    st.global.b32 [out], 3;
 ; CHECK-NEXT:    bra.uni $L__BB0_6;
 ; CHECK-NEXT:  $L__BB0_3: // %case1
-; CHECK-NEXT:    mov.b32 %r5, 1;
-; CHECK-NEXT:    st.global.b32 [out], %r5;
+; CHECK-NEXT:    st.global.b32 [out], 1;
 ; CHECK-NEXT:  $L__BB0_6: // %end
 ; CHECK-NEXT:    ret;
 entry:
diff --git a/llvm/test/CodeGen/NVPTX/local-stack-frame.ll b/llvm/test/CodeGen/NVPTX/local-stack-frame.ll
index 2bfd891a04a17..aaf71dd9884c6 100644
--- a/llvm/test/CodeGen/NVPTX/local-stack-frame.ll
+++ b/llvm/test/CodeGen/NVPTX/local-stack-frame.ll
@@ -144,7 +144,7 @@ define void @foo4() {
 ; PTX32-NEXT:    .local .align 4 .b8 __local_depot3[8];
 ; PTX32-NEXT:    .reg .b32 %SP;
 ; PTX32-NEXT:    .reg .b32 %SPL;
-; PTX32-NEXT:    .reg .b32 %r<6>;
+; PTX32-NEXT:    .reg .b32 %r<5>;
 ; PTX32-EMPTY:
 ; PTX32-NEXT:  // %bb.0:
 ; PTX32-NEXT:    mov.b32 %SPL, __local_depot3;
@@ -153,9 +153,8 @@ define void @foo4() {
 ; PTX32-NEXT:    add.u32 %r2, %SPL, 0;
 ; PTX32-NEXT:    add.u32 %r3, %SP, 4;
 ; PTX32-NEXT:    add.u32 %r4, %SPL, 4;
-; PTX32-NEXT:    mov.b32 %r5, 0;
-; PTX32-NEXT:    st.local.b32 [%r2], %r5;
-; PTX32-NEXT:    st.local.b32 [%r4], %r5;
+; PTX32-NEXT:    st.local.b32 [%r2], 0;
+; PTX32-NEXT:    st.local.b32 [%r4], 0;
 ; PTX32-NEXT:    { // callseq 1, 0
 ; PTX32-NEXT:    .param .b32 param0;
 ; PTX32-NEXT:    st.param.b32 [param0], %r1;
@@ -181,7 +180,6 @@ define void @foo4() {
 ; PTX64-NEXT:    .local .align 4 .b8 __local_depot3[8];
 ; PTX64-NEXT:    .reg .b64 %SP;
 ; PTX64-NEXT:    .reg .b64 %SPL;
-; PTX64-NEXT:    .reg .b32 %r<2>;
 ; PTX64-NEXT:    .reg .b64 %rd<5>;
 ; PTX64-EMPTY:
 ; PTX64-NEXT:  // %bb.0:
@@ -191,9 +189,8 @@ define void @foo4() {
 ; PTX64-NEXT:    add.u64 %rd2, %SPL, 0;
 ; PTX64-NEXT:    add.u64 %rd3, %SP, 4;
 ; PTX64-NEXT:    add.u64 %rd4, %SPL, 4;
-; PTX64-NEXT:    mov.b32 %r1, 0;
-; PTX64-NEXT:    st.local.b32 [%rd2], %r1;
-; PTX64-NEXT:    st.local.b32 [%rd4], %r1;
+; PTX64-NEXT:    st.local.b32 [%rd2], 0;
+; PTX64-NEXT:    st.local.b32 [%rd4], 0;
 ; PTX64-NEXT:    { // callseq 1, 0
 ; PTX64-NEXT:    .param .b64 param0;
 ; PTX64-NEXT:    st.param.b64 [param0], %rd1;
diff --git a/llvm/test/CodeGen/NVPTX/lower-alloca.ll b/llvm/test/CodeGen/NVPTX/lower-alloca.ll
index 489bcf4a7d55c..57c1e5826c89a 100644
--- a/llvm/test/CodeGen/NVPTX/lower-alloca.ll
+++ b/llvm/test/CodeGen/NVPTX/lower-alloca.ll
@@ -15,7 +15,7 @@ define ptx_kernel void @kernel() {
 ; LOWERALLOCAONLY: [[V1:%.*]] = addrspacecast ptr %A to ptr addrspace(5)
 ; LOWERALLOCAONLY: [[V2:%.*]] = addrspacecast ptr addrspace(5) [[V1]] to ptr
 ; LOWERALLOCAONLY: store i32 0, ptr [[V2]], align 4
-; PTX: st.local.b32 [{{%rd[0-9]+}}], {{%r[0-9]+}}
+; PTX: st.local.b32 [{{%rd[0-9]+}}], 0
   store i32 0, ptr %A
   call void @callee(ptr %A)
   ret void
@@ -26,7 +26,7 @@ define void @alloca_in_explicit_local_as() {
 ; PTX-LABEL: .visible .func alloca_in_explicit_local_as(
   %A = alloca i32, addrspace(5)
 ; CHECK: store i32 0, ptr addrspace(5) {{%.+}}
-; PTX: st.local.b32 [%SP], {{%r[0-9]+}}
+; PTX: st.local.b32 [%SP], 0
 ; LOWERALLOCAONLY: [[V1:%.*]] = addrspacecast ptr addrspace(5) %A to ptr
 ; LOWERALLOCAONLY: store i32 0, ptr [[V1]], align 4
   store i32 0, ptr addrspace(5) %A
diff --git a/llvm/test/CodeGen/NVPTX/lower-byval-args.ll b/llvm/test/CodeGen/NVPTX/lower-byval-args.ll
index 54495cf0d61f3..0c2c39e4d5246 100644
--- a/llvm/test/CodeGen/NVPTX/lower-byval-args.ll
+++ b/llvm/test/CodeGen/NVPTX/lower-byval-args.ll
@@ -658,7 +658,7 @@ define ptx_kernel void @test_select_write(ptr byval(i32) align 4 %input1, ptr by
 ; PTX-NEXT:    .reg .b64 %SPL;
 ; PTX-NEXT:    .reg .pred %p<2>;
 ; PTX-NEXT:    .reg .b16 %rs<3>;
-; PTX-NEXT:    .reg .b32 %r<4>;
+; PTX-NEXT:    .reg .b32 %r<3>;
 ; PTX-NEXT:    .reg .b64 %rd<6>;
 ; PTX-EMPTY:
 ; PTX-NEXT:  // %bb.0: // %bb
@@ -674,8 +674,7 @@ define ptx_kernel void @test_select_write(ptr byval(i32) align 4 %input1, ptr by
 ; PTX-NEXT:    add.u64 %rd2, %SPL, 4;
 ; PTX-NEXT:    add.u64 %rd4, %SPL, 0;
 ; PTX-NEXT:    selp.b64 %rd5, %rd2, %rd4, %p1;
-; PTX-NEXT:    mov.b32 %r3, 1;
-; PTX-NEXT:    st.local.b32 [%rd5], %r3;
+; PTX-NEXT:    st.local.b32 [%rd5], 1;
 ; PTX-NEXT:    ret;
 bb:
   %ptrnew = select i1 %cond, ptr %input1, ptr %input2
@@ -838,7 +837,7 @@ define ptx_kernel void @test_phi_write(ptr byval(%struct.S) align 4 %input1, ptr
 ; PTX-NEXT:    .reg .b64 %SPL;
 ; PTX-NEXT:    .reg .pred %p<2>;
 ; PTX-NEXT:    .reg .b16 %rs<3>;
-; PTX-NEXT:    .reg .b32 %r<4>;
+; PTX-NEXT:    .reg .b32 %r<3>;
 ; PTX-NEXT:    .reg .b64 %rd<7>;
 ; PTX-EMPTY:
 ; PTX-NEXT:  // %bb.0: // %bb
@@ -857,8 +856,7 @@ define ptx_kernel void @test_phi_write(ptr byval(%struct.S) align 4 %input1, ptr
 ; PTX-NEXT:  // %bb.1: // %second
 ; PTX-NEXT:    mov.b64 %rd6, %rd1;
 ; PTX-NEXT:  $L__BB14_2: // %merge
-; PTX-NEXT:    mov.b32 %r3, 1;
-; PTX-NEXT:    st.local.b32 [%rd6], %r3;
+; PTX-NEXT:    st.local.b32 [%rd6], 1;
 ; PTX-NEXT:    ret;
 bb:
   br i1 %cond, label %first, label %second
diff --git a/llvm/test/CodeGen/NVPTX/param-align.ll b/llvm/test/CodeGen/NVPTX/param-align.ll
index 16220fb4d47bb..c85080fdf295a 100644
--- a/llvm/test/CodeGen/NVPTX/param-align.ll
+++ b/llvm/test/CodeGen/NVPTX/param-align.ll
@@ -73,12 +73,10 @@ define ptx_device void @t6() {
 ; CHECK-LABEL: .func check_ptr_align1(
 ; CHECK: 	ld.param.b64 	%rd1, [check_ptr_align1_param_0];
 ; CHECK-NOT: 	ld.param.b8
-; CHECK: 	mov.b32 	%r1, 0;
-; CHECK: 	st.b8 	[%rd1+3], %r1;
-; CHECK: 	st.b8 	[%rd1+2], %r1;
-; CHECK: 	st.b8 	[%rd1+1], %r1;
-; CHECK: 	mov.b32 	%r2, 1;
-; CHECK: 	st.b8 	[%rd1], %r2;
+; CHECK: 	st.b8 	[%rd1+3], 0;
+; CHECK: 	st.b8 	[%rd1+2], 0;
+; CHECK: 	st.b8 	[%rd1+1], 0;
+; CHECK: 	st.b8 	[%rd1], 1;
 ; CHECK: 	ret;
 define void @check_ptr_align1(ptr align 1 %_arg_ptr) {
 entry:
@@ -89,10 +87,8 @@ entry:
 ; CHECK-LABEL: .func check_ptr_align2(
 ; CHECK: 	ld.param.b64 	%rd1, [check_ptr_align2_param_0];
 ; CHECK-NOT: 	ld.param.b16
-; CHECK: 	mov.b32 	%r1, 0;
-; CHECK: 	st.b16 	[%rd1+2], %r1;
-; CHECK: 	mov.b32 	%r2, 2;
-; CHECK: 	st.b16 	[%rd1], %r2;
+; CHECK: 	st.b16 	[%rd1+2], 0;
+; CHECK: 	st.b16 	[%rd1], 2;
 ; CHECK: 	ret;
 define void @check_ptr_align2(ptr align 2 %_arg_ptr) {
 entry:
@@ -103,8 +99,7 @@ entry:
 ; CHECK-LABEL: .func check_ptr_align4(
 ; CHECK: 	ld.param.b64 	%rd1, [check_ptr_align4_param_0];
 ; CHECK-NOT: 	ld.param.b32
-; CHECK: 	mov.b32 	%r1, 4;
-; CHECK: 	st.b32 	[%rd1], %r1;
+; CHECK: 	st.b32 	[%rd1], 4;
 ; CHECK: 	ret;
 define void @check_ptr_align4(ptr align 4 %_arg_ptr) {
 entry:
@@ -114,8 +109,7 @@ entry:
 
 ; CHECK-LABEL: .func check_ptr_align8(
 ; CHECK: 	ld.param.b64 	%rd1, [check_ptr_align8_param_0];
-; CHECK: 	mov.b32 	%r1, 8;
-; CHECK: 	st.b32 	[%rd1], %r1;
+; CHECK: 	st.b32 	[%rd1], 8;
 ; CHECK: 	ret;
 define void @check_ptr_align8(ptr align 8 %_arg_ptr) {
 entry:
diff --git a/llvm/test/CodeGen/NVPTX/pr13291-i1-store.ll b/llvm/test/CodeGen/NVPTX/pr13291-i1-store.ll
index b6f1964c54c76..cd2505c20d39c 100644
--- a/llvm/test/CodeGen/NVPTX/pr13291-i1-store.ll
+++ b/llvm/test/CodeGen/NVPTX/pr13291-i1-store.ll
@@ -4,10 +4,8 @@
 ; RUN: %if ptxas %{ llc < %s -mtriple=nvptx64 -mcpu=sm_20 | %ptxas-verify %}
 
 define ptx_kernel void @t1(ptr %a) {
-; PTX32:      mov.b16 %rs{{[0-9]+}}, 0;
-; PTX32-NEXT: st.global.b8 [%r{{[0-9]+}}], %rs{{[0-9]+}};
-; PTX64:      mov.b16 %rs{{[0-9]+}}, 0;
-; PTX64-NEXT: st.global.b8 [%rd{{[0-9]+}}], %rs{{[0-9]+}};
+; PTX32:      st.global.b8 [%r{{[0-9]+}}], 0;
+; PTX64:      st.global.b8 [%rd{{[0-9]+}}], 0;
   store i1 false, ptr %a
   ret void
 }
diff --git a/llvm/test/CodeGen/NVPTX/reg-types.ll b/llvm/test/CodeGen/NVPTX/reg-types.ll
index fb065e1b01bbe..ea45bfdc5e190 100644
--- a/llvm/test/CodeGen/NVPTX/reg-types.ll
+++ b/llvm/test/CodeGen/NVPTX/reg-types...
[truncated]

Member

@Artem-B Artem-B left a comment

Is the patch intended to cover v16/v2f16 ? It does not seem to have any effect on f16 instructions in the tests.

@AlexMaclean
Member Author

> Is the patch intended to cover v16/v2f16 ? It does not seem to have any effect on f16 instructions in the tests.

There are no instructions that we could fold in those tests but this change does cover v16/v2f16.

Contributor

@dakersnar dakersnar left a comment

LGTM, nice simplification, PTX looks much more readable.

Member

Artem-B commented Jun 24, 2025

> Is the patch intended to cover v16/v2f16 ? It does not seem to have any effect on f16 instructions in the tests.

We do have some tests for f16 instructions, but they still use manual checks and may not cover the changes in

; Check that we can lower fadd with immediate arguments.
; CHECK-LABEL: test_fadd_imm_0(
; CHECK-DAG:  ld.param.b16    [[B:%rs[0-9]+]], [test_fadd_imm_0_param_0];
; ...
; CHECK-F16-FTZ-DAG:    mov.b16        [[A:%rs[0-9]+]], 0x3C00;
; CHECK-F16-FTZ-NEXT:   add.rn.ftz.f16     [[R:%rs[0-9]+]], [[B]], [[A]];
; ...
; CHECK-NEXT: st.param.b16    [func_retval0], [[R]];
; CHECK-NEXT: ret;
define half @test_fadd_imm_0(half %b) #0 {
  %r = fadd half 1.0, %b
  ret half %r
}

Looks like the patch didn't allow direct use of const for f16.
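
For readers wondering where the `0x3C00` in the `mov.b16` check line comes from: it is the IEEE 754 binary16 bit pattern of the `half 1.0` constant. The standalone C++ sketch below shows the arithmetic. `halfBitsFromFloat` is a hypothetical helper written for this note, not anything from LLVM; it assumes the input is a normal number in both float and half range and simply truncates mantissa bits that do not fit.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Hypothetical helper: IEEE 754 binary16 bits for a float.
// Assumption: f is a normal number in both formats (no zero, subnormal,
// infinity, NaN, or out-of-range handling); extra mantissa bits are truncated.
std::uint16_t halfBitsFromFloat(float f) {
  std::uint32_t bits;
  std::memcpy(&bits, &f, sizeof bits); // reinterpret the float's bit pattern

  std::uint32_t sign = (bits >> 31) & 0x1;
  // binary32 stores the exponent with a bias of 127.
  std::int32_t exponent = static_cast<std::int32_t>((bits >> 23) & 0xFF) - 127;
  std::uint32_t mantissa = bits & 0x7FFFFF; // 23 fraction bits

  // binary16 uses a bias of 15 and keeps the top 10 fraction bits.
  std::uint32_t halfExponent = static_cast<std::uint32_t>(exponent + 15);
  return static_cast<std::uint16_t>((sign << 15) | (halfExponent << 10) |
                                    (mantissa >> 13));
}
```

For `1.0f` the sign is 0, the unbiased exponent is 0 (stored as 15, i.e. `01111`), and the fraction is 0, which gives `0 01111 0000000000`, or `0x3C00`.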

@AlexMaclean
Member Author

> Looks like the patch didn't allow direct use of const for f16.

This change only affects the handling of operands in st instructions, while that case is an add instruction. It's not expected/desired that this change would have any impact on that test case.
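
As a rough illustration of that scope, here is a standalone C++ toy model of what `selectPossiblyImm` in the diff does to a stored value. It uses none of LLVM's types: `Kind`, `Node`, and this free function are simplified stand-ins invented for the sketch, with `bits` as a placeholder payload.

```cpp
#include <cassert>

// Simplified stand-ins for SelectionDAG node kinds (assumption, not LLVM's enum).
enum class Kind {
  Register,
  Constant,
  ConstantFP,
  Bitcast,
  TargetConstant,
  TargetConstantFP
};

struct Node {
  Kind kind;
  long long bits;      // placeholder payload: constant value or register id
  const Node *operand; // only meaningful for Bitcast; non-owning
};

// Toy version of the helper: look through one bitcast, then turn plain
// constants into "target" constants that can be printed as immediates.
Node selectPossiblyImm(const Node &in) {
  const Node *v = &in;
  if (v->kind == Kind::Bitcast && v->operand)
    v = v->operand; // peel a single bitcast, as the patch does
  if (v->kind == Kind::Constant)
    return Node{Kind::TargetConstant, v->bits, nullptr};
  if (v->kind == Kind::ConstantFP)
    return Node{Kind::TargetConstantFP, v->bits, nullptr};
  return *v; // anything else stays a register operand
}
```

In the patch only `tryStore` and `tryStoreVector` route their value operands through this helper, so the operands of other instructions, such as the `add` in the quoted test, never reach it and keep going through a register.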

@Artem-B
Member

Artem-B commented Jun 25, 2025

> This change is only changing the handling of operands in st instructions while that case is an add instruction. It's not expected/desired that this change would have any impact on that test case.

OK, so it is due to the lack of tests after all. Do you mind adding a handful of test cases for f16/v2f16?

Other than that, the patch LGTM.

@AlexMaclean AlexMaclean force-pushed the dev/amaclean/upstream/st-imm branch from a9fddf3 to 336b4b9 Compare June 26, 2025 00:44
@AlexMaclean AlexMaclean merged commit 16e712e into llvm:main Jun 26, 2025
5 of 7 checks passed