[Question] Performance of mx.fast.scaled_dot_product_attention #1193
Comments
The performance differences you are seeing are likely due to implicit casting. When you call sdpa it will promote the inputs to a common dtype.
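A minimal sketch, assuming the promotion works as described above, that makes this visible by checking the output dtype when bfloat16 q, k, v are combined with a float32 mask (shapes mirror the benchmark below):

```python
import mlx.core as mx

# bfloat16 q/k/v, float32 mask (mx.zeros defaults to float32)
q = mx.random.uniform(shape=(1, 32, 1, 128)).astype(mx.bfloat16)
k = mx.random.uniform(shape=(1, 32, 16, 128)).astype(mx.bfloat16)
v = mx.random.uniform(shape=(1, 32, 16, 128)).astype(mx.bfloat16)
mask = mx.zeros(shape=(1, 16))

out = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)
# If the inputs are promoted to a common dtype as described above,
# this should print float32 rather than bfloat16 (an assumption, not verified here).
print(out.dtype)
```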
Thanks @awni for the reply! I did another experiment to explicitly cast the dtypes to float32 before calling sdpa. Here's the code and output:

import time

import mlx.core as mx


def benchmark_sdpa(q: mx.array, k: mx.array, v: mx.array, mask: mx.array):
    mx.eval(q, k, v, mask)
    for _ in range(100):
        mx.eval(mx.fast.scaled_dot_product_attention(
            q, k, v, scale=1, mask=mask
        ))
    toi = time.perf_counter()
    for _ in range(3200):
        mx.eval(mx.fast.scaled_dot_product_attention(q, k, v, scale=1, mask=mask))
    toc = time.perf_counter()
    tpi = 1e3 * (toc - toi) / 3200
    print(f"q {q.dtype}, k {k.dtype}, v {v.dtype}, mask {mask.dtype} takes {tpi} ms")


def benchmark_sdpa_with_explicit_dtype_cast(q: mx.array, k: mx.array, v: mx.array, mask: mx.array):
    mx.eval(q, k, v, mask)
    for _ in range(100):
        mx.eval(mx.fast.scaled_dot_product_attention(
            q.astype(mask.dtype), k.astype(mask.dtype), v.astype(mask.dtype), scale=1, mask=mask
        ))
    toi = time.perf_counter()
    for _ in range(3200):
        mx.eval(mx.fast.scaled_dot_product_attention(q.astype(mask.dtype), k.astype(mask.dtype), v.astype(mask.dtype), scale=1, mask=mask))
    toc = time.perf_counter()
    tpi = 1e3 * (toc - toi) / 3200
    print(f"all casted to {mask.dtype}, q {q.dtype}, k {k.dtype}, v {v.dtype}, mask {mask.dtype} takes {tpi} ms")


if __name__ == "__main__":
    print(mx.default_device())
    q = mx.random.uniform(shape=(1, 32, 1, 4096 // 32)).astype(mx.bfloat16)
    k = mx.random.uniform(shape=(1, 32, 16, 4096 // 32)).astype(mx.bfloat16)
    v = mx.random.uniform(shape=(1, 32, 16, 4096 // 32)).astype(mx.bfloat16)
    mask = mx.zeros(shape=(1, 16)).astype(mx.bfloat16)
    mx.eval(q, k, v, mask)
    for i in range(5):
        mx.metal.clear_cache()
        print(f"run {i}")
        # O = softmax(Q @ K.T * scale + mask, dim=-1) @ V
        benchmark_sdpa(q, k, v, mask)
        benchmark_sdpa(q, k, v, mask.astype(mx.float32))
        benchmark_sdpa(q, k, v.astype(mx.float32), mask.astype(mx.float32))
        benchmark_sdpa(q.astype(mx.float32), k, v, mask.astype(mx.float32))
        benchmark_sdpa(q.astype(mx.float32), k, v.astype(mx.float32), mask.astype(mx.float32))
        benchmark_sdpa(q.astype(mx.float32), k.astype(mx.float32), v.astype(mx.float32), mask.astype(mx.float32))
        benchmark_sdpa_with_explicit_dtype_cast(q, k, v, mask)
        benchmark_sdpa_with_explicit_dtype_cast(q, k, v, mask.astype(mx.float32))
        benchmark_sdpa_with_explicit_dtype_cast(q, k, v.astype(mx.float32), mask.astype(mx.float32))
        benchmark_sdpa_with_explicit_dtype_cast(q.astype(mx.float32), k, v, mask.astype(mx.float32))
        benchmark_sdpa_with_explicit_dtype_cast(q.astype(mx.float32), k, v.astype(mx.float32), mask.astype(mx.float32))
        benchmark_sdpa_with_explicit_dtype_cast(q.astype(mx.float32), k.astype(mx.float32), v.astype(mx.float32), mask.astype(mx.float32))

Output
Please correct me if I'm wrong, but based on the output it appears that casting before sdpa can actually speed things up. Could you help me understand why implicit casting inside sdpa seems slower than explicit casting outside of it? I didn't expect it to make such a difference.
Sorry, there are a lot of results there and I'm not sure what you are looking at. Could you point me to the two cases that are unexpected to you?
Sure! For example, it appears that casting all inputs to the same dtype before calling sdpa is faster than passing them with mixed dtypes.
I don't see the same results on my M1 Max. They look pretty similar, though there is some variance in the timings in general:
Describe the bug
I implemented a model with mx.fast.scaled_dot_product_attention but observed that performance improves significantly when I apply mask = mask.astype(q.dtype) before sdpa. The model dtype is bfloat16, and the mask has dtype float32 before the astype. I therefore ran an experiment to benchmark mx.fast.scaled_dot_product_attention with different input dtypes.

To Reproduce
Output
Expected behavior
Before this experiment, I thought bfloat16 would be the fastest and float32 would be the slowest. However, from the experiment results it seems that mx.fast.scaled_dot_product_attention is fast when q, k, v and mask all have the same dtype (either bfloat16 or float32), and slower when q, k, v and mask have different dtypes. Could you help me understand the reason? Is it due to the implicit dtype conversion? Since torch's sdpa requires q, k, v and mask to have the same dtype, I also wonder whether the implicit dtype conversion in mlx is intended.
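For reference, a minimal sketch of the workaround mentioned in the description (casting the mask to the model dtype before calling sdpa); the helper name `attend` is hypothetical and the shapes are reused from the benchmark above:

```python
import mlx.core as mx

def attend(q: mx.array, k: mx.array, v: mx.array, mask: mx.array, scale: float = 1.0) -> mx.array:
    # Hypothetical helper: cast the float32 mask to the model dtype (e.g. bfloat16)
    # so all sdpa inputs share one dtype and no implicit promotion happens inside the call.
    mask = mask.astype(q.dtype)
    return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)

q = mx.random.uniform(shape=(1, 32, 1, 128)).astype(mx.bfloat16)
k = mx.random.uniform(shape=(1, 32, 16, 128)).astype(mx.bfloat16)
v = mx.random.uniform(shape=(1, 32, 16, 128)).astype(mx.bfloat16)
mask = mx.zeros(shape=(1, 16))  # float32 by default
print(attend(q, k, v, mask).dtype)  # expected: bfloat16
```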
Desktop (please complete the following information):