
Conversation

@clouds56 (Contributor) commented Dec 26, 2025

Summary by CodeRabbit

  • New Features

    • Unified, cross-backend profiling: consistent timing and synchronization across CUDA, MPS, and CPU with automatic device selection and CPU fallback.
    • Metal backend: added ability to retrieve kernel source for inspection or diagnostics.
  • Bug Fixes

    • More robust kernel caching: gracefully skips backends that do not produce on-disk cacheable artifacts and avoids erroneous write attempts.


@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai coderabbitai bot (Contributor) commented Dec 26, 2025

📝 Walkthrough

Walkthrough

Adds defensive handling when kernel adapters lack on-disk artifacts, exposes Metal kernel source retrieval via a new adapter method, and implements cross-backend (CUDA/MPS/CPU) timing in the profiler with unified Event abstraction, centralized synchronization, and device selection.

Changes

  • Kernel cache / Metal adapter (tilelang/cache/kernel_cache.py, tilelang/jit/adapter/torch/metal.py): _save_so_cubin_to_disk now obtains libpath via getattr(..., "libpath", None) and returns early when absent. MetalKernelAdapter sets libpath = None and adds get_kernel_source(self, kernel_only: bool = False) -> str.
  • Profiler: cross-backend timing (tilelang/profiler/bench.py): introduces a CPU timing fallback (CpuEvent), IS_MPS detection and device selection (cuda/mps/cpu), unified Event usage replacing direct torch.cuda.Event, a _synchronize() helper for backend-aware sync, and refactored timing functions that use the new abstractions and the device variable.
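
For orientation, here is a minimal sketch of the cross-backend timing abstraction the walkthrough describes. The names (IS_MPS, Event, _synchronize, device) follow the summary above, but the actual implementation in tilelang/profiler/bench.py may differ, and torch.mps.Event availability depends on the PyTorch build:

import time

import torch

IS_CUDA = torch.cuda.is_available()
IS_MPS = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
device = "cuda" if IS_CUDA else ("mps" if IS_MPS else "cpu")


class Event:
    """Unified timing event: CUDA/MPS events when available, perf_counter_ns on CPU."""

    def __init__(self, enable_timing: bool = True):
        if IS_CUDA:
            self.inner = torch.cuda.Event(enable_timing=enable_timing)
        elif IS_MPS:
            self.inner = torch.mps.Event(enable_timing=enable_timing)
        else:
            self.inner = None
        self.record_time = None  # CPU fallback timestamp in nanoseconds

    def record(self) -> None:
        if self.inner is not None:
            self.inner.record()
        else:
            self.record_time = time.perf_counter_ns()

    def elapsed_time(self, end: "Event") -> float:
        if self.inner is not None and end.inner is not None:
            return self.inner.elapsed_time(end.inner)  # milliseconds
        if self.record_time is not None and end.record_time is not None:
            return (end.record_time - self.record_time) / 1e6  # ns -> ms
        return 0.0  # timing unavailable; the review below suggests warning here


def _synchronize() -> None:
    if IS_CUDA:
        torch.cuda.synchronize()
    elif IS_MPS:
        torch.mps.synchronize()
    # CPU execution is synchronous, so no sync is needed.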

Sequence Diagram(s)

sequenceDiagram
    participant Caller
    participant Bench as bench.py
    participant Device
    participant Timer as Event
    participant Sync as _synchronize

    rect rgb(220,235,245)
    Note over Caller,Bench: Start benchmark → detect backend
    Caller->>Bench: do_bench(...)
    Bench->>Device: detect (cuda / mps / cpu)
    Bench->>Bench: set `device`, set `IS_MPS`
    end

    rect rgb(250,240,220)
    Note over Caller,Timer: Warmup (record + run + sync)
    Caller->>Timer: Event.record(start)
    Caller->>Device: execute warmup on `device`
    Caller->>Sync: _synchronize()
    Sync->>Device: backend-aware sync
    end

    rect rgb(230,245,230)
    Note over Caller,Timer: Timed iterations (record, run, record, sync)
    loop iterations
      Caller->>Timer: Event.record(start)
      Caller->>Device: run iteration on `device`
      Caller->>Timer: Event.record(end)
      Caller->>Sync: _synchronize()
    end
    Timer->>Caller: elapsed_time() -> duration
    end

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Poem

🐇 I hopped from CUDA through MPS to CPU,

Saved sources when binaries were few,
Timers now listen on every shore,
Syncs keep the beats, and benchmarks soar.

Pre-merge checks and finishing touches

❌ Failed checks (1 warning, 1 inconclusive)

  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 16.67%, which is below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve coverage.
  • Title check: ❓ Inconclusive. The title 'improve benchmark on mps' is partially related to the changeset but lacks specificity. While the PR does improve benchmarking for MPS (Metal), the title doesn't clearly convey the main technical change: adding cross-backend profiling support with CPU/CUDA/MPS fallback logic. Resolution: consider a more descriptive title such as 'Add cross-backend profiling support for CUDA, MPS, and CPU' or 'Enable benchmark profiling on MPS with fallback to CPU' to better communicate the scope and technical intent.

✅ Passed checks (1 passed)

  • Description Check: ✅ Passed. Check skipped because CodeRabbit's high-level summary is enabled.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment


@coderabbitai coderabbitai bot (Contributor) left a comment

Actionable comments posted: 3

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 875b42f and 7d2fd4b.

📒 Files selected for processing (3)
  • tilelang/cache/kernel_cache.py
  • tilelang/jit/adapter/torch/metal.py
  • tilelang/profiler/bench.py
🔇 Additional comments (5)
tilelang/jit/adapter/torch/metal.py (1)

29-31: LGTM! Clear documentation of Metal backend behavior.

The comment effectively explains why libpath is set to None — Metal kernels are compiled at runtime via torch.mps.compile_shader without producing a cacheable shared library. This aligns with the defensive getattr(..., "libpath", None) check in kernel_cache.py.

tilelang/cache/kernel_cache.py (1)

410-414: Defensive handling for backends without cacheable binaries.

The getattr with None default correctly avoids AttributeError for adapters missing libpath, and the early return prevents attempting to cache non-existent artifacts.

However, note that _get_required_files() (lines 419-422) still includes kernel_lib_path as a required file for cache loading. This means Metal kernels will never load from disk cache since the .so file won't exist. If disk caching is intended for Metal in the future (e.g., caching the source code), consider making the library file optional in _get_required_files() when the backend doesn't produce one.
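
For illustration, a sketch of the defensive write path together with the optional-library variant suggested above; method bodies and attribute names here are illustrative, not the actual kernel_cache.py signatures:

def _save_so_cubin_to_disk(self, adapter) -> None:
    # Backends such as Metal compile shaders at runtime and expose no .so/.cubin.
    libpath = getattr(adapter, "libpath", None)
    if libpath is None:
        return  # nothing cacheable on disk, so skip the write entirely
    # ... copy the shared library / cubin into the cache directory ...


def _get_required_files(self) -> list[str]:
    # Hypothetical variant: only require the compiled library when the backend
    # actually produces one, so Metal kernels can still load from the cache.
    required = [self.params_path, self.source_path]  # illustrative attribute names
    if getattr(self.adapter, "libpath", None) is not None:
        required.append(self.kernel_lib_path)
    return required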

tilelang/profiler/bench.py (3)

107-111: LGTM! Clean synchronization abstraction.

The _synchronize() helper correctly dispatches to the appropriate backend. The implicit no-op for CPU (no else branch) is intentional and correct since CPU operations are synchronous.


151-168: Correct usage of new timing abstractions.

The benchmark setup properly uses _synchronize() and the new Event class. The cache buffer allocation now correctly uses the detected device variable instead of hardcoding CUDA.


196-207: Event-based timing correctly uses the new abstraction.

The per-iteration event recording and final synchronization via _synchronize() is appropriate. The global sync at line 207 ensures all events complete before collecting elapsed times.
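
Schematically, that record/run/record pattern with a single trailing sync looks like this (a sketch reusing the Event/_synchronize names; variable names are illustrative):

start_events = [Event(enable_timing=True) for _ in range(n_repeat)]
end_events = [Event(enable_timing=True) for _ in range(n_repeat)]
for i in range(n_repeat):
    cache.zero_()            # flush caches between iterations
    start_events[i].record()
    fn()                     # kernel under benchmark
    end_events[i].record()
_synchronize()               # one global sync before reading any timings
times = [s.elapsed_time(e) for s, e in zip(start_events, end_events)]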

@coderabbitai coderabbitai bot (Contributor) left a comment

Actionable comments posted: 0

♻️ Duplicate comments (2)
tilelang/profiler/bench.py (2)

97-105: Device detection logic correct, but must move before Event class.

The priority order (CUDA → MPS → CPU) is appropriate, and the hasattr check for torch.backends.mps ensures compatibility with older PyTorch versions.

As noted above, move these definitions before the Event class definition to resolve the forward reference error.


60-95: Critical: Forward reference error still present.

The Event class references IS_CUDA (line 65) and IS_MPS (line 67) before they are defined at lines 97-98. This will cause a NameError at runtime when the class is instantiated.

Move the definitions before the class:

+IS_CUDA = torch.cuda.is_available()
+IS_MPS = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
+
 class Event:

Then remove the duplicate definitions at lines 97-98.

Additional concerns:

  1. Docstring is understated (line 61): The description "Dummy event class for CPU timing compatibility" does not accurately reflect that this is a full cross-backend timing abstraction for CUDA, MPS, and CPU.

  2. Silent fallback (line 94): Returning 0.0 when timing is unavailable could mask failures. Consider logging a warning to make timing issues visible during debugging.

🧹 Nitpick comments (2)
tilelang/profiler/bench.py (2)

187-207: Consider renaming _bench_with_cuda_events for clarity.

The function implementation correctly uses the new Event wrapper and _synchronize() for cross-backend support. However, the function name _bench_with_cuda_events is now misleading since it supports CUDA, MPS, and CPU.

Consider renaming to _bench_with_events or _bench_with_timing_events to reflect its broader scope.

🔎 Proposed rename
-def _bench_with_cuda_events(
+def _bench_with_events(
     fn: Callable,
     cache: torch.Tensor,
     n_repeat: int,
     quantiles: list[float] | None,
     return_mode: str,
 ) -> float | list[float]:
-    """Benchmark using CUDA events for timing."""
+    """Benchmark using Event-based timing across backends."""

And update the call site at line 180:

-        return _bench_with_cuda_events(fn, cache, n_repeat, quantiles, return_mode)
+        return _bench_with_events(fn, cache, n_repeat, quantiles, return_mode)

157-157: Consider conditionally allocating the L2 cache flush buffer only for CUDA.

The cache buffer is now created on all backends (line 157), but the L2 cache flushing technique is specific to NVIDIA GPUs. On MPS, the different cache hierarchy means cache.zero_() may not provide the intended consistent cache state. On CPU, automatic cache management means zeroing a buffer adds unnecessary memory pressure without providing deterministic cache flushing.

Either document why the cache buffer benefits MPS and CPU, or allocate it only when IS_CUDA is true.
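
One possible shape for that change, assuming the IS_CUDA/device/fast_flush names used elsewhere in the file (a sketch, not the committed code):

if IS_CUDA:
    cache_size = int(256e6 // 4) if fast_flush else int(256e6)
    cache_dtype = torch.int if fast_flush else torch.int8
    cache = torch.empty(cache_size, dtype=cache_dtype, device=device)
else:
    # MPS/CPU: a tiny placeholder keeps later cache.zero_() calls valid and cheap.
    cache = torch.empty(1, dtype=torch.float32, device="cpu")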

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 7d2fd4b and e422509.

📒 Files selected for processing (1)
  • tilelang/profiler/bench.py
🔇 Additional comments (4)
tilelang/profiler/bench.py (4)

7-7: LGTM: Import supports CPU timing fallback.

The time module is correctly imported to enable perf_counter_ns() for CPU timing in the Event class.


107-111: LGTM: Clean synchronization abstraction.

The _synchronize() helper centralizes backend-specific synchronization logic and eliminates repetitive conditionals throughout the codebase.


151-151: LGTM: Consistent use of cross-backend synchronization.

Correctly replaced direct CUDA synchronization with the new _synchronize() helper.


160-167: LGTM: Correct migration to cross-backend timing.

The timing estimation correctly uses the new Event wrapper and _synchronize() helper, maintaining the same logic while enabling MPS and CPU support.

@coderabbitai coderabbitai bot (Contributor) left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tilelang/profiler/bench.py (1)

187-194: Update the docstring to reflect cross-backend support.

The docstring "Benchmark using CUDA events for timing" is outdated. The function now uses the cross-backend Event abstraction that supports CUDA, MPS, and CPU timing.

🔎 Proposed fix
 def _bench_with_events(
     fn: Callable,
     cache: torch.Tensor,
     n_repeat: int,
     quantiles: list[float] | None,
     return_mode: str,
 ) -> float | list[float]:
-    """Benchmark using CUDA events for timing."""
+    """Benchmark using cross-backend timing events (CUDA/MPS/CPU)."""
♻️ Duplicate comments (1)
tilelang/profiler/bench.py (1)

70-71: Update the docstring to reflect cross-backend support.

The docstring "Dummy event class for CPU timing compatibility" understates the purpose of this class. It's actually a full cross-backend timing abstraction that unifies CUDA, MPS, and CPU timing, not just a dummy placeholder.

🔎 Proposed fix
-class Event:
-    """Dummy event class for CPU timing compatibility."""
+class Event:
+    """Cross-backend timing event abstraction for CUDA, MPS, and CPU."""
🧹 Nitpick comments (2)
tilelang/jit/adapter/torch/metal.py (1)

76-77: Consider handling the kernel_only parameter explicitly or documenting its omission.

The kernel_only parameter is currently unused (flagged by static analysis), and the default value differs from the base class (False vs. True). While Metal may only have a single source variant (similar to CuTeDSL), the implementation could be clearer.

Consider one of these approaches:

  1. Explicit handling (even if both branches return the same value):

    def get_kernel_source(self, kernel_only: bool = True) -> str:
        # Metal only has kernel_global_source for both host and device
        return self.kernel_global_source or ""
  2. Document the intentional no-op if Metal truly doesn't distinguish:

    def get_kernel_source(self, kernel_only: bool = True) -> str:
        # Metal backend has unified source; kernel_only is ignored
        return self.kernel_global_source or ""

Also, align the default value with BaseKernelAdapter (True instead of False) for consistency across adapters.

Based on learnings, CuTeDSL similarly has unified source, but being explicit about the behavior improves maintainability.

tilelang/profiler/bench.py (1)

98-104: Consider warning when timing is unavailable.

The method silently returns 0.0 when timing is unavailable (line 104). This could hide timing failures and make debugging difficult. Consider logging a warning to alert users that timing data is not available.

🔎 Proposed fix

Add logging import at the top of the file:

+import logging
 import os
 import sys

Then update the elapsed_time method:

     def elapsed_time(self, end_event: Event) -> float:
         if self.inner is not None and end_event.inner is not None:
             return self.inner.elapsed_time(end_event.inner)  # type: ignore
         elif self.record_time is not None and end_event.record_time is not None:
             return (end_event.record_time - self.record_time) / 1e6  # Convert ns to ms
         else:
+            logging.warning("Timing unavailable: events were not properly recorded")
             return 0.0
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e422509 and 1c8b041.

📒 Files selected for processing (2)
  • tilelang/jit/adapter/torch/metal.py
  • tilelang/profiler/bench.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-12-26T06:45:47.669Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1483
File: tilelang/jit/adapter/cutedsl/adapter.py:93-95
Timestamp: 2025-12-26T06:45:47.669Z
Learning: For the CuTeDSL backend in tilelang/jit/adapter/cutedsl/adapter.py, the host_kernel_source and device_kernel_source have the same value.

Applied to files:

  • tilelang/jit/adapter/torch/metal.py
🧬 Code graph analysis (1)
tilelang/jit/adapter/torch/metal.py (4)
tilelang/jit/adapter/cython/adapter.py (2)
  • libpath (366-368)
  • get_kernel_source (380-387)
tilelang/jit/kernel.py (1)
  • get_kernel_source (438-449)
tilelang/jit/adapter/nvrtc/adapter.py (1)
  • get_kernel_source (179-190)
tilelang/jit/adapter/base.py (1)
  • get_kernel_source (89-93)
🪛 Ruff (0.14.10)
tilelang/jit/adapter/torch/metal.py

76-76: Unused method argument: kernel_only

(ARG002)

tilelang/profiler/bench.py

184-184: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (9)
tilelang/jit/adapter/torch/metal.py (1)

29-31: Good defensive handling for Metal's artifact-less compilation.

The explicit libpath = None with a clear explanation prevents issues when the kernel cache attempts to retrieve the library path. This aligns well with Metal's shader compilation model via torch.mps.compile_shader.

tilelang/profiler/bench.py (8)

7-7: LGTM!

The time import is correctly added to support CPU timing fallback via perf_counter_ns() in the Event class.


60-67: LGTM! Forward reference issue resolved.

The device detection logic is correct and properly ordered (CUDA → MPS → CPU). The hasattr check ensures compatibility with older PyTorch versions. Moving these definitions before the Event class successfully resolves the previous forward reference error.


107-111: LGTM!

The _synchronize() helper correctly centralizes synchronization logic across backends. The ordering (CUDA → MPS → implicit CPU) is appropriate since CPU execution is synchronous by default.


151-151: LGTM!

Correctly replaced torch.cuda.synchronize() with the new cross-backend _synchronize() helper.


160-161: LGTM!

The timing estimation correctly uses the new Event wrapper and _synchronize() helper for cross-backend compatibility.

Also applies to: 167-167


180-180: LGTM!

Correctly updated to call the renamed _bench_with_events function.


196-197: LGTM!

The function correctly uses the new Event wrapper for cross-backend timing and _synchronize() for proper synchronization across all backends.

Also applies to: 207-207


88-96: Reconsider using per-event synchronization instead of global MPS synchronization.

The code currently uses torch.mps.synchronize() (global device sync) as the primary approach, with a comment claiming per-event synchronization can hang. However, PyTorch MPS documentation actually recommends the opposite: prefer torch.mps.event.Event for fine-grained synchronization when you only need to wait for specific work. Global torch.mps.synchronize() should be reserved as a fallback, as it can hide deadlocks and hangs during kernel errors and is less efficient than event-based synchronization. Consider using event.synchronize() instead to only block on the specific event that needs completion.

Likely an incorrect or invalid review comment.

 cache_size = int(256e6 // 4) if fast_flush else int(256e6)
 cache_dtype = torch.int if fast_flush else torch.int8
-cache = torch.empty(cache_size, dtype=cache_dtype, device="cuda")
+cache = torch.empty(cache_size, dtype=cache_dtype, device=device)
Contributor

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

cat -n timelang/profiler/bench.py | head -250 | tail -100

Repository: tile-ai/tilelang

Length of output: 118


🏁 Script executed:

# Find the bench.py file first
fd bench.py

Repository: tile-ai/tilelang

Length of output: 86


🏁 Script executed:

wc -l tilelang/profiler/bench.py

Repository: tile-ai/tilelang

Length of output: 90


🏁 Script executed:

# Read the section around line 157 and check cache operations
sed -n '140,170p' tilelang/profiler/bench.py
sed -n '155,175p' tilelang/profiler/bench.py

Repository: tile-ai/tilelang

Length of output: 2074


🏁 Script executed:

# Check for cache_dtype definition and zero_() calls
sed -n '1,160p' tilelang/profiler/bench.py | tail -50
sed -n '155,250p' tilelang/profiler/bench.py

Repository: tile-ai/tilelang

Length of output: 5380


🏁 Script executed:

# Find device variable initialization and beginning of do_bench function
sed -n '1,160p' tilelang/profiler/bench.py

Repository: tile-ai/tilelang

Length of output: 5667


🏁 Script executed:

# Also check if there's any MPS-specific handling in the file
rg -n 'device.*=' tilelang/profiler/bench.py | head -20
rg -n 'mps' tilelang/profiler/bench.py

Repository: tile-ai/tilelang

Length of output: 445


🏁 Script executed:

# Search for any MPS-specific issues or workarounds in the codebase
rg -n 'int.*mps|mps.*int' tilelang/
rg -n 'TODO|FIXME|BUG|NOTE.*mps' tilelang/profiler/bench.py

Repository: tile-ai/tilelang

Length of output: 314


🏁 Script executed:

# Check if there are any tests for MPS in the codebase
fd -e py | xargs grep -l 'IS_MPS\|test.*mps' 2>/dev/null | head -10

Repository: tile-ai/tilelang

Length of output: 111


🌐 Web query:

PyTorch MPS int dtype support torch.empty int8 int32 zero_

💡 Result:

Short answer: MPS primarily supports floating dtypes (float32/float16, bfloat16 limited) and has incomplete/limited integer dtype support — many integer-only ops (including some tensor creation/backing ops) are not implemented on the MPS backend. As a result, creating int8/int32 tensors directly on MPS with torch.empty(..., dtype=torch.int8/int32, device='mps') or calling in-place integer fills like .zero_() on an MPS integer tensor often fails or is unsupported. See PyTorch MPS docs and multiple MPS issues describing missing dtype/operator support. [1][2][3]

Workarounds:

  • Keep integer tensors on CPU (preferred) and only move floats to MPS.
  • If you must create data on CPU, create with integer dtype on CPU and perform CPU-side ops; avoid moving integer tensors to MPS.
  • Use float tensors on MPS (e.g., float32) and cast/quantize on CPU when needed.

References:

  • PyTorch MPS docs (dtype/operator notes). [1]
  • PyTorch issues & discussions about missing MPS dtype/operator support and limitations. [2][3]

Sources:
[1] PyTorch torch.mps documentation.
[2] pytorch/pytorch GitHub — MPS dtype/operator issues (examples).
[3] pytorch/pytorch GitHub issues discussing MPS missing types/ops.


MPS device incompatibility: integer tensor operations not supported on MPS backend.

The code will fail on MPS devices (Apple Silicon). PyTorch MPS does not support torch.empty() with dtype=torch.int or torch.int8 when device='mps:0', and does not support .zero_() on integer tensors on MPS. This will cause runtime failures when torch.backends.mps.is_available() is true. The cache buffer must be created on CPU or use float32 on MPS, with integer operations handled separately.

🤖 Prompt for AI Agents
In tilelang/profiler/bench.py around line 157, creating an integer cache tensor
directly on an MPS device is incompatible with PyTorch MPS (integer dtypes and
.zero_() are unsupported); detect torch.backends.mps.is_available() and if true
either allocate the cache on CPU (device='cpu') for integer dtypes or allocate a
float32 cache on MPS and map/round/convert integer semantics in CPU-space as
needed; update subsequent code paths to explicitly move or convert the buffer
when interacting with MPS so integer operations happen on CPU (or use float
operations on MPS), and ensure dtype and device selection are centralized so
tests cover both MPS and non-MPS branches.

@LeiWang1999 LeiWang1999 requested a review from oraluben December 27, 2025 05:06
@oraluben (Collaborator)

@clouds56 Did you run profiling on MPS? I think it's blocked by upstream issue pytorch/pytorch#162872. Also, an extra abstraction on top of Event might introduce noise.

@clouds56 (Contributor, Author)

I just want to enable profiling on MPS.
We could make Event an alias of torch.cuda.Event/torch.mps.Event, but this might make type annotations a mess.



IS_CUDA = torch.cuda.is_available()
device = "cuda:0" if IS_CUDA else "mps:0"
Collaborator

I think we can keep these and do the same thing for the synchronize function, and use Event/synchronize below, instead of creating a new Event class.

Contributor (Author)

Ok, and what's the story for CPU? Shall we create a CpuEvent that uses time? Or leave it unsupported?

Collaborator

> Ok, and what's the story for CPU? Shall we create a CpuEvent that uses time? Or leave it unsupported?

I didn't test how the CPU backend works; I think we can leave it unchanged.

@coderabbitai coderabbitai bot (Contributor) left a comment

Actionable comments posted: 0

♻️ Duplicate comments (1)
tilelang/profiler/bench.py (1)

141-154: MPS backend incompatibility: integer tensor operations are unsupported.

PyTorch MPS does not support torch.empty() with integer dtypes (torch.int/torch.int8) or .zero_() on integer tensors. When device="mps:0", this code will raise a runtime error.

🔎 Proposed fix: allocate cache on CPU for MPS
     cache_size = int(256e6 // 4) if fast_flush else int(256e6)
     cache_dtype = torch.int if fast_flush else torch.int8
-    cache = torch.empty(cache_size, dtype=cache_dtype, device=device)
+    # MPS doesn't support integer tensor operations; use CPU for cache
+    cache_device = "cpu" if IS_MPS else device
+    cache = torch.empty(cache_size, dtype=cache_dtype, device=cache_device)

This keeps the cache on CPU when running on MPS, which still achieves the L2 flush effect for CPU-side caches while avoiding the MPS integer dtype limitation.

Alternatively, if MPS L2 cache flushing is not required, consider skipping cache operations entirely on MPS.

🧹 Nitpick comments (2)
tilelang/profiler/bench.py (2)

60-78: Consider removing unused *args and adding a warning for missing timing data.

The *args parameter on line 63 is never used. Additionally, returning 0.0 silently when record_time is None could mask timing failures during debugging.

🔎 Proposed refinements
-    def __init__(self, *args, **kwargs):
+    def __init__(self, **kwargs):
         self.enable_timing = kwargs.get("enable_timing", False)
         self.record_time = None

For the silent return, consider logging a warning:

def elapsed_time(self, end_event: CpuEvent) -> float:
    if self.record_time is not None and end_event.record_time is not None:
        return (end_event.record_time - self.record_time) / 1e6
    import warnings
    warnings.warn("CpuEvent timing unavailable, returning 0.0")
    return 0.0

208-218: Consider adding early validation for CUPTI backend on non-CUDA devices.

The CUPTI backend is CUDA-only. If a user explicitly requests backend="cupti" on MPS or CPU, the code will fail with an unclear profiler error rather than a descriptive message.

🔎 Proposed guard in do_bench

Add validation before calling _bench_with_cupti:

     # Benchmarking phase
     if backend == "event":
         return _bench_with_events(fn, cache, n_repeat, quantiles, return_mode)
     elif backend == "cupti":
+        if not IS_CUDA:
+            raise ValueError(f"CUPTI backend requires CUDA; current device is {device}")
         return _bench_with_cupti(fn, cache, n_repeat)
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 1c8b041 and 17789d9.

📒 Files selected for processing (1)
  • tilelang/profiler/bench.py
🧰 Additional context used
🪛 Ruff (0.14.10)
tilelang/profiler/bench.py

63-63: Unused method argument: args

(ARG002)


170-170: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (3)
tilelang/profiler/bench.py (3)

80-91: LGTM!

The forward reference issue is resolved. Device detection priority (CUDA → MPS → CPU) and the hasattr guard for older PyTorch versions without MPS support are correct.


93-97: LGTM!

The centralized synchronization helper correctly handles CUDA and MPS backends. CPU requires no explicit synchronization since operations are already synchronous.


173-205: Cross-backend timing logic is correct.

The function properly uses the unified Event abstraction and _synchronize(). The cache.zero_() call on line 187 depends on the cache being allocated correctly in do_bench—once the MPS integer tensor issue is fixed upstream, this function will work correctly.

Event = CpuEvent


def _synchronize() -> None:
Collaborator

We don't want an extra function call here, so it might be better to assign a different implementation to synchronize, as with Event above.
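
For example, one way to avoid the per-call indirection, sketched under the assumption that the backend is fixed at import time (torch.mps.Event requires a PyTorch build that exposes MPS events; CpuEvent is the fallback discussed above):

def _noop() -> None:
    pass


if IS_CUDA:
    Event = torch.cuda.Event
    synchronize = torch.cuda.synchronize
elif IS_MPS:
    Event = torch.mps.Event
    synchronize = torch.mps.synchronize
else:
    Event = CpuEvent          # perf_counter_ns-based fallback
    synchronize = _noop       # CPU execution is already synchronous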
