improve benchmark on mps #1547
Conversation
👋 Hi! Thank you for contributing to the TileLang project. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
📝 Walkthrough
Adds defensive handling when kernel adapters lack on-disk artifacts, exposes Metal kernel source retrieval via a new adapter method, and implements cross-backend (CUDA/MPS/CPU) timing in the profiler with a unified Event abstraction, centralized synchronization, and device selection.
Sequence Diagram(s)
sequenceDiagram
participant Caller
participant Bench as bench.py
participant Device
participant Timer as Event
participant Sync as _synchronize
rect rgb(220,235,245)
Note over Caller,Bench: Start benchmark → detect backend
Caller->>Bench: do_bench(...)
Bench->>Device: detect (cuda / mps / cpu)
Bench->>Bench: set `device`, set `IS_MPS`
end
rect rgb(250,240,220)
Note over Caller,Timer: Warmup (record + run + sync)
Caller->>Timer: Event.record(start)
Caller->>Device: execute warmup on `device`
Caller->>Sync: _synchronize()
Sync->>Device: backend-aware sync
end
rect rgb(230,245,230)
Note over Caller,Timer: Timed iterations (record, run, record, sync)
loop iterations
Caller->>Timer: Event.record(start)
Caller->>Device: run iteration on `device`
Caller->>Timer: Event.record(end)
Caller->>Sync: _synchronize()
end
Timer->>Caller: elapsed_time() -> duration
end
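For readers following the diagram, a minimal self-contained sketch of that flow is shown below. It reuses the PR's names (`Event`, `_synchronize`) but the bodies are simplified stand-ins rather than the actual `bench.py` implementation; in particular, the non-CUDA path here simply synchronizes and reads a wall clock.

```python
# Simplified sketch of the record -> run -> record -> sync flow described above.
import time
import torch

IS_CUDA = torch.cuda.is_available()
IS_MPS = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()


def _synchronize() -> None:
    if IS_CUDA:
        torch.cuda.synchronize()
    elif IS_MPS:
        torch.mps.synchronize()
    # CPU execution is synchronous, so nothing to do.


class Event:
    """Stand-in for the PR's Event wrapper: CUDA events when available, wall clock otherwise."""

    def __init__(self, enable_timing: bool = True):
        self.inner = torch.cuda.Event(enable_timing=enable_timing) if IS_CUDA else None
        self.record_time = None

    def record(self) -> None:
        if self.inner is not None:
            self.inner.record()
        else:
            _synchronize()  # drain queued MPS work before timestamping (a source of noise, see the discussion below)
            self.record_time = time.perf_counter_ns()

    def elapsed_time(self, end: "Event") -> float:
        if self.inner is not None:
            return self.inner.elapsed_time(end.inner)
        return (end.record_time - self.record_time) / 1e6  # ns -> ms


def do_bench_sketch(fn, n_repeat: int = 10) -> float:
    fn()              # warmup
    _synchronize()
    starts = [Event() for _ in range(n_repeat)]
    ends = [Event() for _ in range(n_repeat)]
    for i in range(n_repeat):
        starts[i].record()
        fn()
        ends[i].record()
    _synchronize()    # make sure all recorded events have completed
    return sum(s.elapsed_time(e) for s, e in zip(starts, ends)) / n_repeat
```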
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
Pre-merge checks and finishing touches
❌ Failed checks (1 warning, 1 inconclusive)
✅ Passed checks (1 passed)
Actionable comments posted: 3
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
tilelang/cache/kernel_cache.py
tilelang/jit/adapter/torch/metal.py
tilelang/profiler/bench.py
🔇 Additional comments (5)
tilelang/jit/adapter/torch/metal.py (1)

29-31: LGTM! Clear documentation of Metal backend behavior.
The comment effectively explains why `libpath` is set to `None` — Metal kernels are compiled at runtime via `torch.mps.compile_shader` without producing a cacheable shared library. This aligns with the defensive `getattr(..., "libpath", None)` check in `kernel_cache.py`.
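For context, the pattern being described looks roughly like the schematic below. The attribute names (`libpath`, `kernel_global_source`) and `torch.mps.compile_shader` come from the PR; the class itself is a simplified stand-in, not the real adapter.

```python
# Schematic only: Metal kernels are JIT-compiled, so there is no .so artifact to cache.
import torch


class MetalAdapterSketch:
    def __init__(self, kernel_global_source: str):
        self.kernel_global_source = kernel_global_source
        # No on-disk shared library is produced; downstream cache code must tolerate None.
        self.libpath = None
        # Compile the shader at runtime instead of loading a prebuilt library.
        self._lib = torch.mps.compile_shader(kernel_global_source)

    def get_kernel_source(self, kernel_only: bool = True) -> str:
        # Metal has a single unified source, so kernel_only is effectively ignored.
        return self.kernel_global_source or ""
```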
tilelang/cache/kernel_cache.py (1)

410-414: Defensive handling for backends without cacheable binaries.
The `getattr` with a `None` default correctly avoids `AttributeError` for adapters missing `libpath`, and the early return prevents attempting to cache non-existent artifacts.
However, note that `_get_required_files()` (lines 419-422) still includes `kernel_lib_path` as a required file for cache loading. This means Metal kernels will never load from the disk cache since the `.so` file won't exist. If disk caching is intended for Metal in the future (e.g., caching the source code), consider making the library file optional in `_get_required_files()` when the backend doesn't produce one.
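A condensed sketch of the defensive check being discussed (the function name and arguments here are illustrative; the real logic lives in `kernel_cache.py`):

```python
# Illustrative: skip disk caching when the adapter produced no binary artifact.
def _save_library_to_cache_sketch(adapter, cache_dir: str) -> None:
    libpath = getattr(adapter, "libpath", None)  # tolerate adapters that never set libpath
    if libpath is None:
        return  # nothing to copy (e.g. Metal); avoid AttributeError and bogus cache entries
    # ... copy the shared library and other required files into cache_dir ...
```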
tilelang/profiler/bench.py (3)

107-111: LGTM! Clean synchronization abstraction.
The `_synchronize()` helper correctly dispatches to the appropriate backend. The implicit no-op for CPU (no `else` branch) is intentional and correct since CPU operations are synchronous.

151-168: Correct usage of new timing abstractions.
The benchmark setup properly uses `_synchronize()` and the new `Event` class. The cache buffer allocation now correctly uses the detected `device` variable instead of hardcoding CUDA.

196-207: Event-based timing correctly uses the new abstraction.
The per-iteration event recording and final synchronization via `_synchronize()` is appropriate. The global sync at line 207 ensures all events complete before collecting elapsed times.
Actionable comments posted: 0
♻️ Duplicate comments (2)
tilelang/profiler/bench.py (2)
97-105: Device detection logic correct, but must move before the `Event` class.
The priority order (CUDA → MPS → CPU) is appropriate, and the `hasattr` check for `torch.backends.mps` ensures compatibility with older PyTorch versions.
As noted above, move these definitions before the `Event` class definition to resolve the forward reference error.
60-95: Critical: Forward reference error still present.
The `Event` class references `IS_CUDA` (line 65) and `IS_MPS` (line 67) before they are defined at lines 97-98. This will cause a `NameError` at runtime when the class is instantiated.
Move the definitions before the class:

+IS_CUDA = torch.cuda.is_available()
+IS_MPS = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
+
 class Event:

Then remove the duplicate definitions at lines 97-98.
Additional concerns:
Docstring is understated (line 61): The description "Dummy event class for CPU timing compatibility" does not accurately reflect that this is a full cross-backend timing abstraction for CUDA, MPS, and CPU.
Silent fallback (line 94): Returning `0.0` when timing is unavailable could mask failures. Consider logging a warning to make timing issues visible during debugging.
🧹 Nitpick comments (2)
tilelang/profiler/bench.py (2)
187-207: Consider renaming `_bench_with_cuda_events` for clarity.
The function implementation correctly uses the new `Event` wrapper and `_synchronize()` for cross-backend support. However, the function name `_bench_with_cuda_events` is now misleading since it supports CUDA, MPS, and CPU.
Consider renaming to `_bench_with_events` or `_bench_with_timing_events` to reflect its broader scope.

🔎 Proposed rename

-def _bench_with_cuda_events(
+def _bench_with_events(
     fn: Callable,
     cache: torch.Tensor,
     n_repeat: int,
     quantiles: list[float] | None,
     return_mode: str,
 ) -> float | list[float]:
-    """Benchmark using CUDA events for timing."""
+    """Benchmark using Event-based timing across backends."""

And update the call site at line 180:

-    return _bench_with_cuda_events(fn, cache, n_repeat, quantiles, return_mode)
+    return _bench_with_events(fn, cache, n_repeat, quantiles, return_mode)
157-157: Consider conditionally allocating the L2 cache flush buffer only for CUDA.
The cache buffer is now created on all backends (line 157), but the L2 cache flushing technique is specific to NVIDIA GPUs. On MPS, the different cache hierarchy means `cache.zero_()` may not provide the intended consistent cache state. On CPU, automatic cache management means zeroing a buffer adds unnecessary memory pressure without providing deterministic cache flushing.
Either document why the cache buffer benefits MPS and CPU, or allocate it only when `IS_CUDA` is true.
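Taken literally, that suggestion might look something like the sketch below (the helper name is made up for illustration; the constants mirror the existing allocation):

```python
# Sketch: allocate the L2-flush buffer only where it is known to help (CUDA).
import torch


def make_flush_cache(device: str, fast_flush: bool = True):
    if not device.startswith("cuda"):
        return None  # MPS/CPU: skip the flush buffer entirely
    cache_size = int(256e6 // 4) if fast_flush else int(256e6)
    cache_dtype = torch.int if fast_flush else torch.int8
    return torch.empty(cache_size, dtype=cache_dtype, device=device)
```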
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
tilelang/profiler/bench.py
🔇 Additional comments (4)
tilelang/profiler/bench.py (4)
7-7: LGTM: Import supports CPU timing fallback.
The `time` module is correctly imported to enable `perf_counter_ns()` for CPU timing in the `Event` class.

107-111: LGTM: Clean synchronization abstraction.
The `_synchronize()` helper centralizes backend-specific synchronization logic and eliminates repetitive conditionals throughout the codebase.

151-151: LGTM: Consistent use of cross-backend synchronization.
Correctly replaced direct CUDA synchronization with the new `_synchronize()` helper.

160-167: LGTM: Correct migration to cross-backend timing.
The timing estimation correctly uses the new `Event` wrapper and `_synchronize()` helper, maintaining the same logic while enabling MPS and CPU support.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tilelang/profiler/bench.py (1)
187-194: Update the docstring to reflect cross-backend support.
The docstring "Benchmark using CUDA events for timing" is outdated. The function now uses the cross-backend `Event` abstraction that supports CUDA, MPS, and CPU timing.

🔎 Proposed fix

 def _bench_with_events(
     fn: Callable,
     cache: torch.Tensor,
     n_repeat: int,
     quantiles: list[float] | None,
     return_mode: str,
 ) -> float | list[float]:
-    """Benchmark using CUDA events for timing."""
+    """Benchmark using cross-backend timing events (CUDA/MPS/CPU)."""
♻️ Duplicate comments (1)
tilelang/profiler/bench.py (1)
70-71: Update the docstring to reflect cross-backend support.
The docstring "Dummy event class for CPU timing compatibility" understates the purpose of this class. It's actually a full cross-backend timing abstraction that unifies CUDA, MPS, and CPU timing, not just a dummy placeholder.

🔎 Proposed fix

-class Event:
-    """Dummy event class for CPU timing compatibility."""
+class Event:
+    """Cross-backend timing event abstraction for CUDA, MPS, and CPU."""
🧹 Nitpick comments (2)
tilelang/jit/adapter/torch/metal.py (1)
76-77: Consider handling the `kernel_only` parameter explicitly or documenting its omission.
The `kernel_only` parameter is currently unused (flagged by static analysis), and the default value differs from the base class (`False` vs. `True`). While Metal may only have a single source variant (similar to CuTeDSL), the implementation could be clearer.
Consider one of these approaches:

Explicit handling (even if both branches return the same value):

def get_kernel_source(self, kernel_only: bool = True) -> str:
    # Metal only has kernel_global_source for both host and device
    return self.kernel_global_source or ""

Document the intentional no-op if Metal truly doesn't distinguish:

def get_kernel_source(self, kernel_only: bool = True) -> str:
    # Metal backend has unified source; kernel_only is ignored
    return self.kernel_global_source or ""

Also, align the default value with `BaseKernelAdapter` (`True` instead of `False`) for consistency across adapters.
Based on learnings, CuTeDSL similarly has unified source, but being explicit about the behavior improves maintainability.
tilelang/profiler/bench.py (1)
98-104: Consider warning when timing is unavailable.
The method silently returns `0.0` when timing is unavailable (line 104). This could hide timing failures and make debugging difficult. Consider logging a warning to alert users that timing data is not available.

🔎 Proposed fix

Add the logging import at the top of the file:

+import logging
 import os
 import sys

 def elapsed_time(self, end_event: Event) -> float:
     if self.inner is not None and end_event.inner is not None:
         return self.inner.elapsed_time(end_event.inner)  # type: ignore
     elif self.record_time is not None and end_event.record_time is not None:
         return (end_event.record_time - self.record_time) / 1e6  # Convert ns to ms
     else:
+        logging.warning("Timing unavailable: events were not properly recorded")
         return 0.0
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
tilelang/jit/adapter/torch/metal.py
tilelang/profiler/bench.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-12-26T06:45:47.669Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1483
File: tilelang/jit/adapter/cutedsl/adapter.py:93-95
Timestamp: 2025-12-26T06:45:47.669Z
Learning: For the CuTeDSL backend in tilelang/jit/adapter/cutedsl/adapter.py, the host_kernel_source and device_kernel_source have the same value.
Applied to files:
tilelang/jit/adapter/torch/metal.py
🧬 Code graph analysis (1)
tilelang/jit/adapter/torch/metal.py (4)
tilelang/jit/adapter/cython/adapter.py (2)
libpath (366-368)
get_kernel_source (380-387)

tilelang/jit/kernel.py (1)
get_kernel_source (438-449)

tilelang/jit/adapter/nvrtc/adapter.py (1)
get_kernel_source (179-190)

tilelang/jit/adapter/base.py (1)
get_kernel_source (89-93)
🪛 Ruff (0.14.10)
tilelang/jit/adapter/torch/metal.py
76-76: Unused method argument: kernel_only
(ARG002)
tilelang/profiler/bench.py
184-184: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (9)
tilelang/jit/adapter/torch/metal.py (1)
29-31: Good defensive handling for Metal's artifact-less compilation.
The explicit `libpath = None` with a clear explanation prevents issues when the kernel cache attempts to retrieve the library path. This aligns well with Metal's shader compilation model via `torch.mps.compile_shader`.

tilelang/profiler/bench.py (8)

7-7: LGTM!
The `time` import is correctly added to support the CPU timing fallback via `perf_counter_ns()` in the Event class.

60-67: LGTM! Forward reference issue resolved.
The device detection logic is correct and properly ordered (CUDA → MPS → CPU). The `hasattr` check ensures compatibility with older PyTorch versions. Moving these definitions before the `Event` class successfully resolves the previous forward reference error.

107-111: LGTM!
The `_synchronize()` helper correctly centralizes synchronization logic across backends. The ordering (CUDA → MPS → implicit CPU) is appropriate since CPU execution is synchronous by default.

151-151: LGTM!
Correctly replaced `torch.cuda.synchronize()` with the new cross-backend `_synchronize()` helper.

160-161: LGTM!
The timing estimation correctly uses the new `Event` wrapper and `_synchronize()` helper for cross-backend compatibility.
Also applies to: 167-167

180-180: LGTM!
Correctly updated to call the renamed `_bench_with_events` function.

196-197: LGTM!
The function correctly uses the new `Event` wrapper for cross-backend timing and `_synchronize()` for proper synchronization across all backends.
Also applies to: 207-207
88-96: Reconsider using per-event synchronization instead of global MPS synchronization.
The code currently uses `torch.mps.synchronize()` (global device sync) as the primary approach, with a comment claiming per-event synchronization can hang. However, PyTorch MPS documentation actually recommends the opposite: prefer `torch.mps.event.Event` for fine-grained synchronization when you only need to wait for specific work. Global `torch.mps.synchronize()` should be reserved as a fallback, as it can hide deadlocks and hangs during kernel errors and is less efficient than event-based synchronization. Consider using `event.synchronize()` instead to block only on the specific event that needs completion.
Likely an incorrect or invalid review comment.
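For reference, the event-based alternative this comment alludes to would look roughly like the sketch below. Whether it is safe here is exactly what is disputed in this thread (the code comment and the upstream issue linked later suggest it can hang), so treat it as illustrative only.

```python
# Sketch of per-event MPS synchronization, as documented for torch.mps.event.Event.
import torch

start = torch.mps.event.Event(enable_timing=True)
end = torch.mps.event.Event(enable_timing=True)

start.record()
# ... launch the kernel under test ...
end.record()
end.synchronize()                     # wait only for this event, not the whole device
elapsed_ms = start.elapsed_time(end)  # milliseconds between the two recorded events
```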
 cache_size = int(256e6 // 4) if fast_flush else int(256e6)
 cache_dtype = torch.int if fast_flush else torch.int8
-cache = torch.empty(cache_size, dtype=cache_dtype, device="cuda")
+cache = torch.empty(cache_size, dtype=cache_dtype, device=device)
🧩 Analysis chain
🏁 Script executed:
cat -n timelang/profiler/bench.py | head -250 | tail -100

Repository: tile-ai/tilelang
Length of output: 118
🏁 Script executed:
# Find the bench.py file first
fd bench.py

Repository: tile-ai/tilelang
Length of output: 86
🏁 Script executed:
wc -l tilelang/profiler/bench.py

Repository: tile-ai/tilelang
Length of output: 90
🏁 Script executed:
# Read the section around line 157 and check cache operations
sed -n '140,170p' tilelang/profiler/bench.py
sed -n '155,175p' tilelang/profiler/bench.py

Repository: tile-ai/tilelang
Length of output: 2074
🏁 Script executed:
# Check for cache_dtype definition and zero_() calls
sed -n '1,160p' tilelang/profiler/bench.py | tail -50
sed -n '155,250p' tilelang/profiler/bench.py

Repository: tile-ai/tilelang
Length of output: 5380
🏁 Script executed:
# Find device variable initialization and beginning of do_bench function
sed -n '1,160p' tilelang/profiler/bench.py

Repository: tile-ai/tilelang
Length of output: 5667
🏁 Script executed:
# Also check if there's any MPS-specific handling in the file
rg -n 'device.*=' tilelang/profiler/bench.py | head -20
rg -n 'mps' tilelang/profiler/bench.py

Repository: tile-ai/tilelang
Length of output: 445
🏁 Script executed:
# Search for any MPS-specific issues or workarounds in the codebase
rg -n 'int.*mps|mps.*int' tilelang/
rg -n 'TODO|FIXME|BUG|NOTE.*mps' tilelang/profiler/bench.py

Repository: tile-ai/tilelang
Length of output: 314
🏁 Script executed:
# Check if there are any tests for MPS in the codebase
fd -e py | xargs grep -l 'IS_MPS\|test.*mps' 2>/dev/null | head -10

Repository: tile-ai/tilelang
Length of output: 111
🌐 Web query:
PyTorch MPS int dtype support torch.empty int8 int32 zero_
💡 Result:
Short answer: MPS primarily supports floating dtypes (float32/float16, bfloat16 limited) and has incomplete/limited integer dtype support — many integer-only ops (including some tensor creation/backing ops) are not implemented on the MPS backend. As a result, creating int8/int32 tensors directly on MPS with torch.empty(..., dtype=torch.int8/int32, device='mps') or calling in-place integer fills like .zero_() on an MPS integer tensor often fails or is unsupported. See PyTorch MPS docs and multiple MPS issues describing missing dtype/operator support. [1][2][3]
Workarounds:
- Keep integer tensors on CPU (preferred) and only move floats to MPS.
- If you must create data on CPU, create with integer dtype on CPU and perform CPU-side ops; avoid moving integer tensors to MPS.
- Use float tensors on MPS (e.g., float32) and cast/quantize on CPU when needed.
References:
- PyTorch MPS docs (dtype/operator notes). [1]
- PyTorch issues & discussions about missing MPS dtype/operator support and limitations. [2][3]
Sources:
[1] PyTorch torch.mps documentation.
[2] pytorch/pytorch GitHub — MPS dtype/operator issues (examples).
[3] pytorch/pytorch GitHub issues discussing MPS missing types/ops.
MPS device incompatibility: integer tensor operations not supported on MPS backend.
The code will fail on MPS devices (Apple Silicon). PyTorch MPS does not support torch.empty() with dtype=torch.int or torch.int8 when device='mps:0', and does not support .zero_() on integer tensors on MPS. This will cause runtime failures when torch.backends.mps.is_available() is true. The cache buffer must be created on CPU or use float32 on MPS, with integer operations handled separately.
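If that limitation holds, one workaround along the lines of the web-query suggestions would be to pick the buffer dtype per backend. This is an untested sketch, not part of the PR:

```python
# Sketch: use a float32 flush buffer on MPS, integer dtypes elsewhere.
import torch

IS_CUDA = torch.cuda.is_available()
IS_MPS = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
device = "cuda:0" if IS_CUDA else ("mps:0" if IS_MPS else "cpu")

fast_flush = True
cache_size = int(256e6 // 4) if fast_flush else int(256e6)
int_dtype = torch.int if fast_flush else torch.int8
cache_dtype = torch.float32 if device.startswith("mps") else int_dtype
cache = torch.empty(cache_size, dtype=cache_dtype, device=device)
cache.zero_()  # float zero_() is supported on MPS
```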
🤖 Prompt for AI Agents
In tilelang/profiler/bench.py around line 157, creating an integer cache tensor
directly on an MPS device is incompatible with PyTorch MPS (integer dtypes and
.zero_() are unsupported); detect torch.backends.mps.is_available() and if true
either allocate the cache on CPU (device='cpu') for integer dtypes or allocate a
float32 cache on MPS and map/round/convert integer semantics in CPU-space as
needed; update subsequent code paths to explicitly move or convert the buffer
when interacting with MPS so integer operations happen on CPU (or use float
operations on MPS), and ensure dtype and device selection are centralized so
tests cover both MPS and non-MPS branches.
@clouds56 Did you run the profiler on MPS? I think it's blocked by upstream issue pytorch/pytorch#162872. Also, an extra abstraction on top of events might introduce noise.
I just want to enable profiling on MPS.
IS_CUDA = torch.cuda.is_available()
device = "cuda:0" if IS_CUDA else "mps:0"
I think we can keep these and do the same thing for the synchronize function, and use Event/synchronize below, instead of creating a new Event class.
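Concretely, the aliasing being suggested might look something like this (a sketch only; whether `torch.mps.event.Event` is a drop-in replacement depends on the PyTorch version and the upstream issue mentioned earlier in the thread):

```python
# Sketch: alias the backend's native Event type instead of wrapping it in a new class.
import torch

IS_CUDA = torch.cuda.is_available()
IS_MPS = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()

if IS_CUDA:
    Event = torch.cuda.Event
elif IS_MPS:
    Event = torch.mps.event.Event
else:
    Event = None  # the CPU story is exactly what is being discussed below
```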
Ok, and what's the story for CPU? Shall we create a CpuEvent that uses time? Or leave it unsupported?
I didn't test how the CPU backend works; I think we can leave it unchanged.
Actionable comments posted: 0
♻️ Duplicate comments (1)
tilelang/profiler/bench.py (1)
141-154: MPS backend incompatibility: integer tensor operations are unsupported.
PyTorch MPS does not support `torch.empty()` with integer dtypes (`torch.int`/`torch.int8`) or `.zero_()` on integer tensors. When `device="mps:0"`, this code will raise a runtime error.

🔎 Proposed fix: allocate cache on CPU for MPS

 cache_size = int(256e6 // 4) if fast_flush else int(256e6)
 cache_dtype = torch.int if fast_flush else torch.int8
-cache = torch.empty(cache_size, dtype=cache_dtype, device=device)
+# MPS doesn't support integer tensor operations; use CPU for cache
+cache_device = "cpu" if IS_MPS else device
+cache = torch.empty(cache_size, dtype=cache_dtype, device=cache_device)

This keeps the cache on CPU when running on MPS, which still achieves the L2 flush effect for CPU-side caches while avoiding the MPS integer dtype limitation.
Alternatively, if MPS L2 cache flushing is not required, consider skipping cache operations entirely on MPS.
🧹 Nitpick comments (2)
tilelang/profiler/bench.py (2)
60-78: Consider removing the unused `*args` and adding a warning for missing timing data.
The `*args` parameter on line 63 is never used. Additionally, returning `0.0` silently when `record_time` is `None` could mask timing failures during debugging.

🔎 Proposed refinements

-    def __init__(self, *args, **kwargs):
+    def __init__(self, **kwargs):
         self.enable_timing = kwargs.get("enable_timing", False)
         self.record_time = None

For the silent return, consider logging a warning:

def elapsed_time(self, end_event: CpuEvent) -> float:
    if self.record_time is not None and end_event.record_time is not None:
        return (end_event.record_time - self.record_time) / 1e6
    import warnings
    warnings.warn("CpuEvent timing unavailable, returning 0.0")
    return 0.0
208-218: Consider adding early validation for the CUPTI backend on non-CUDA devices.
The CUPTI backend is CUDA-only. If a user explicitly requests `backend="cupti"` on MPS or CPU, the code will fail with an unclear profiler error rather than a descriptive message.

🔎 Proposed guard in do_bench

Add validation before calling `_bench_with_cupti`:

 # Benchmarking phase
 if backend == "event":
     return _bench_with_events(fn, cache, n_repeat, quantiles, return_mode)
 elif backend == "cupti":
+    if not IS_CUDA:
+        raise ValueError(f"CUPTI backend requires CUDA; current device is {device}")
     return _bench_with_cupti(fn, cache, n_repeat)
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
tilelang/profiler/bench.py
🧰 Additional context used
🪛 Ruff (0.14.10)
tilelang/profiler/bench.py
63-63: Unused method argument: args
(ARG002)
170-170: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (3)
tilelang/profiler/bench.py (3)
80-91: LGTM!
The forward reference issue is resolved. Device detection priority (CUDA → MPS → CPU) and the `hasattr` guard for older PyTorch versions without MPS support are correct.

93-97: LGTM!
The centralized synchronization helper correctly handles CUDA and MPS backends. CPU requires no explicit synchronization since operations are already synchronous.

173-205: Cross-backend timing logic is correct.
The function properly uses the unified `Event` abstraction and `_synchronize()`. The `cache.zero_()` call on line 187 depends on the cache being allocated correctly in `do_bench` — once the MPS integer tensor issue is fixed upstream, this function will work correctly.
Event = CpuEvent

def _synchronize() -> None:
We don't want an extra function call here, so it might be better to assign a different implementation to `synchronize`, as with `Event` above.
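In that spirit, the helper could be bound once at import time rather than branching on every call; a sketch of that idea (not code from the PR):

```python
# Sketch: bind synchronize to the backend implementation once, mirroring the Event aliasing above.
import torch

IS_CUDA = torch.cuda.is_available()
IS_MPS = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()


def _noop() -> None:
    pass


if IS_CUDA:
    synchronize = torch.cuda.synchronize
elif IS_MPS:
    synchronize = torch.mps.synchronize
else:
    synchronize = _noop  # CPU: nothing to wait for
```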
Summary by CodeRabbit
New Features
Bug Fixes