cudnn dot_product_attention encounters Failed to capture gpu graph when running on 2 NVIDIA A6000 Ada cards #27599
Comments
Couldn't reproduce on A100/H100. Will try to find an A6000 machine to test.
I can't reproduce the error on H200 either. However, on an A6000 machine the same error occurs. Please see the sys info and error message below.
Hi, could you try with
to generate the cuDNN logs for me?
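The exact command requested above was not captured in this thread; for reference, a standard way to enable cuDNN 9.x API logging is via the documented environment variables (the reproducer filename `repro.py` below is a hypothetical placeholder):

```shell
# Documented cuDNN 9.x logging knobs (this is an assumption about what was
# asked for above; the original command was not captured in the thread).
export CUDNN_LOGLEVEL_DBG=3         # verbosity: 0 = off ... 3 = most verbose
export CUDNN_LOGDEST_DBG=cudnn.log  # log destination: a file, stdout, or stderr
# Then run the reproducer, e.g.:
# python repro.py   # "repro.py" is a hypothetical filename
```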
Could you also try running with Compute Sanitizer?
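For reference, Compute Sanitizer ships with the CUDA toolkit and can be run as a wrapper around the Python process; a typical invocation looks like the sketch below (`repro.py` is a hypothetical name for the reproducer script):

```shell
# memcheck is Compute Sanitizer's default tool; other tools include
# racecheck, initcheck, and synccheck. Guarded so the snippet degrades
# gracefully on machines without the CUDA toolkit installed.
if command -v compute-sanitizer >/dev/null 2>&1; then
  compute-sanitizer --tool memcheck python repro.py
else
  echo "compute-sanitizer not found; install the CUDA toolkit"
fi
```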
Here you go.
Unfortunately I don't have access to the A6000 cards right now, and the error messages don't make the cause clear. I tried on the similar A40 card and it passed. I noticed that you are using cuDNN 9.1; could you try with cuDNN 9.8 if possible?
Updating cuDNN to 9.8.0 does solve the issue. It looks like my conda environment had a conflict between different cuDNN versions. Thank you @Cjkkkk for the help!
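For anyone hitting the same conflict: one way to check which cuDNN runtime is visible to a pip-based JAX install, and to upgrade it, is sketched below. The package name `nvidia-cudnn-cu12` assumes the CUDA 12 pip wheels; conda environments may pin cuDNN under a different package name.

```shell
# List any cuDNN packages visible to pip (package name assumes CUDA 12 wheels).
pip list 2>/dev/null | grep -i cudnn || echo "no cudnn wheel found via pip"
# Upgrade to cuDNN 9.8+, which resolved this issue:
# pip install -U "nvidia-cudnn-cu12>=9.8"
```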
Description
Hi,
I was trying to use jax.nn.dot_product_attention with implementation=cudnn for flax.linen.MultiHeadDotProductAttention. However, I encountered the error message below when attempting to run on multiple NVIDIA GPUs; the code works fine on a single GPU. Here is a minimal reproducing example from my end. And here is the error message.
Thank you in advance for the help!
System info (python version, jaxlib version, accelerator, etc.)