Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/op/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION(kLayoutVisualizationFormats, String);
TVM_REGISTER_PASS_CONFIG_OPTION(kDeviceCompileFlags, ffi::Array<ffi::String>);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableDataRaceCheck, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kEnableLowerLDGSTG, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableLowerLDGSTGPredicated, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kEnableLowerLDGSTGPredicated, Bool);

/*!
 * \brief Data type used to represent a cuTensorMap value.
 *
 * \return DataType::UInt(8, 128).
 *
 * NOTE(review): presumably 128 lanes of 8-bit unsigned integers, chosen to
 * match the size of the CUDA tensor-map descriptor -- confirm against the
 * DataType::UInt signature and the CUDA driver API.
 */
DataType cuTensorMapType() {
  const DataType tensor_map_dtype = DataType::UInt(8, 128);
  return tensor_map_dtype;
}

Expand Down
12 changes: 6 additions & 6 deletions src/op/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,17 +68,17 @@ static constexpr const char *kDisableShuffleElect = "tl.disable_shuffle_elect";
static constexpr const char *kEnableLowerLDGSTG = "tl.enable_lower_ldgstg";

/*!
* \brief Disable lowering predicated global load/store to ldg/stg intrinsics
* \brief Enable lowering predicated global load/store to ldg/stg intrinsics
*
* When enabled (set to true), predicated loads (if_then_else with else=0) and
* predicated stores (IfThenElse with store in then case) will NOT be lowered
* predicated stores (IfThenElse with store in then case) will be lowered
* to predicated ldg/stg intrinsics.
* Default: OFF (predicated lowering is enabled by default)
* Default: OFF (predicated lowering is disabled by default)
*
* kDisableLowerLDGSTGPredicated = "tl.disable_lower_ldgstg_predicated"
* kEnableLowerLDGSTGPredicated = "tl.enable_lower_ldgstg_predicated"
*/
static constexpr const char *kDisableLowerLDGSTGPredicated =
"tl.disable_lower_ldgstg_predicated";
static constexpr const char *kEnableLowerLDGSTGPredicated =
"tl.enable_lower_ldgstg_predicated";
static constexpr const char *kStorageRewriteDetectInplace =
"tl.storage_rewrite_detect_inplace";
static constexpr const char *kASTPrintEnable = "tl.ast_print_enable";
Expand Down
10 changes: 4 additions & 6 deletions src/transform/lower_ldg_stg.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
* Pass configurations:
* - tl.enable_lower_ldgstg: Enable non-predicated ldg/stg lowering (default:
* OFF)
* - tl.disable_lower_ldgstg_predicated: Disable predicated ldg/stg lowering
* - tl.enable_lower_ldgstg_predicated: Enable predicated ldg/stg lowering
* (default: OFF)
*/

Expand Down Expand Up @@ -491,11 +491,9 @@ tvm::transform::Pass LowerLDGSTG() {
// Non-predicated ldg/stg: default OFF
bool enable_non_predicated =
ctx->GetConfig<Bool>(kEnableLowerLDGSTG, Bool(false)).value();
// Predicated ldg/stg: default ON (so disable flag default is false)
bool disable_predicated =
ctx->GetConfig<Bool>(kDisableLowerLDGSTGPredicated, Bool(false))
.value();
bool enable_predicated = !disable_predicated;
// Predicated ldg/stg: default OFF
bool enable_predicated =
ctx->GetConfig<Bool>(kEnableLowerLDGSTGPredicated, Bool(false)).value();

// If both are disabled, skip the pass entirely
if (!enable_non_predicated && !enable_predicated) {
Expand Down
18 changes: 9 additions & 9 deletions testing/python/transform/test_tilelang_transform_lower_ldgstg.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

Pass configurations:
- tl.enable_lower_ldgstg: Enable non-predicated ldg/stg lowering (default: OFF)
- tl.disable_lower_ldgstg_predicated: Disable predicated ldg/stg lowering (default: OFF)
- tl.enable_lower_ldgstg_predicated: Enable predicated ldg/stg lowering (default: OFF)
"""

from tilelang import tvm as tvm
Expand All @@ -14,15 +14,15 @@
from tvm import tir


def _apply_passes(mod, enable_non_predicated=False, disable_predicated=False):
def _apply_passes(mod, enable_non_predicated=False, enable_predicated=False):
"""Apply the LowerLDGSTG pass and related lowering passes."""
mod = tvm.tir.transform.BindTarget(tvm.target.Target("cuda"))(mod)
mod = tl.transform.FlattenBuffer()(mod)
mod = tl.transform.VectorizeLoop()(mod)
with tvm.transform.PassContext(
config={
PassConfigKey.TL_ENABLE_LOWER_LDGSTG: enable_non_predicated,
PassConfigKey.TL_DISABLE_LOWER_LDGSTG_PREDICATED: disable_predicated,
PassConfigKey.TL_ENABLE_LOWER_LDGSTG_PREDICATED: enable_predicated,
}
):
mod = tl.transform.LowerLDGSTG()(mod)
Expand Down Expand Up @@ -135,7 +135,7 @@ def func(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32"), pred: T
B[i] = T.if_then_else(pred > 0, A[i], T.float32(0))

mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main"))
mod = _apply_passes(mod) # Default: predicated is ON
mod = _apply_passes(mod, enable_predicated=True) # Explicitly enable predicated lowering (default: OFF)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Misleading inline comment.

The comment says "Default: predicated is ON" but with the new semantics, the default is OFF. The test correctly passes enable_predicated=True to enable it, but the comment should be updated.

📝 Suggested fix
-    mod = _apply_passes(mod, enable_predicated=True)  # Default: predicated is ON
+    mod = _apply_passes(mod, enable_predicated=True)  # Explicitly enable predicated lowering

The same applies to lines 155, 172, and 190.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
mod = _apply_passes(mod, enable_predicated=True) # Default: predicated is ON
mod = _apply_passes(mod, enable_predicated=True) # Explicitly enable predicated lowering
🤖 Prompt for AI Agents
In `@testing/python/transform/test_tilelang_transform_lower_ldgstg.py` at line
138, Update the misleading inline comments that state "Default: predicated is
ON" next to calls to _apply_passes(..., enable_predicated=True); the new
semantics default predication to OFF, so change those comments (the ones
adjacent to the _apply_passes calls referencing enable_predicated=True) to
reflect "Default: predicated is OFF" or remove the misleading default note so
the comment matches the actual behavior.

print("=== test_lower_ldg32_predicated ===")
print(mod)
assert _check_has_intrinsic(mod, "ldg32"), "Expected predicated ldg32"
Expand All @@ -152,7 +152,7 @@ def func(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32"), pred: T
B[i] = A[i]

mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main"))
mod = _apply_passes(mod) # Default: predicated is ON
mod = _apply_passes(mod, enable_predicated=True) # Explicitly enable predicated lowering (default: OFF)
print("=== test_lower_stg32_predicated ===")
print(mod)
assert _check_has_intrinsic(mod, "stg32"), "Expected predicated stg32"
Expand All @@ -169,7 +169,7 @@ def func(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32"), pred: T
B[i * 4 + j] = T.if_then_else(pred > 0, A[i * 4 + j], T.float32(0))

mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main"))
mod = _apply_passes(mod) # Default: predicated is ON
mod = _apply_passes(mod, enable_predicated=True) # Explicitly enable predicated lowering (default: OFF)
print("=== test_lower_ldg128_predicated ===")
print(mod)
assert _check_has_intrinsic(mod, "ldg128"), "Expected predicated ldg128"
Expand All @@ -187,7 +187,7 @@ def func(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32"), pred: T
B[i * 4 + j] = A[i * 4 + j]

mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main"))
mod = _apply_passes(mod) # Default: predicated is ON
mod = _apply_passes(mod, enable_predicated=True) # Explicitly enable predicated lowering (default: OFF)
print("=== test_lower_stg128_predicated ===")
print(mod)
assert _check_has_intrinsic(mod, "stg128"), "Expected predicated stg128"
Expand All @@ -204,7 +204,7 @@ def func(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32"), N: T.in
B[idx] = T.if_then_else(idx < N, A[idx], T.float32(0))

mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main"))
mod = _apply_passes(mod, disable_predicated=True)
mod = _apply_passes(mod, enable_predicated=False)
print("=== test_predicated_disabled ===")
print(mod)
# When disabled, no predicated ldg/stg should be generated
Expand Down Expand Up @@ -291,7 +291,7 @@ def test_e2e_load_global_store_global_predicated():
"""End-to-end test that load_global/store_global intrinsics work correctly when enabled."""
import torch

@tilelang.jit(pass_configs={PassConfigKey.TL_ENABLE_LOWER_LDGSTG: True})
@tilelang.jit(pass_configs={PassConfigKey.TL_ENABLE_LOWER_LDGSTG: True, PassConfigKey.TL_ENABLE_LOWER_LDGSTG_PREDICATED: True})
def copy_kernel(X, Y):
N = T.const("N")
X: T.Tensor[[N], T.float32]
Expand Down
6 changes: 3 additions & 3 deletions tilelang/transform/pass_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ class PassConfigKey(str, Enum):
When enabled, converts Ramp-based global buffer load/store to ldg/stg intrinsics.
Default: False"""

TL_DISABLE_LOWER_LDGSTG_PREDICATED = "tl.disable_lower_ldgstg_predicated"
"""Disable predicated LDG/STG lowering.
When False (default), predicated loads (if_then_else with else=0) and
TL_ENABLE_LOWER_LDGSTG_PREDICATED = "tl.enable_lower_ldgstg_predicated"
"""Enable predicated LDG/STG lowering.
When True, predicated loads (if_then_else with else=0) and
predicated stores (IfThenElse with a store in the then case) are lowered to
ldg/stg intrinsics. Default: False"""

Expand Down
Loading