Skip to content

Commit ed071aa

Browse files
Modify the fused attn jax unit test case for head dim qk != head dim v
Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
1 parent 3547a9f commit ed071aa

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

tests/jax/test_fused_attn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -978,7 +978,7 @@ def check_dqkv(primitive, reference, pad, idx):
978978
id="2-2048-1024-12-12-64-32-BF16-CROSS",
979979
),
980980
pytest.param(
981-
2, 2048, 2048, 12, 6, 64, 32, jnp.float16, id="2-2048-2048-12-6-64-32-FP16-GQA"
981+
2, 2048, 2048, 12, 6, 128, 64, jnp.float16, id="2-2048-2048-12-6-128-64-FP16-GQA"
982982
),
983983
],
984984
)

0 commit comments

Comments
 (0)