
Commit c7494e0

KshitijLakhani and pre-commit-ci[bot] authored and committed
[JAX] Add support for Fused Attn MLA head_dim_qk != head_dim_v (NVIDIA#1851)
* Add support for Fused Attn MLA head_dim_qk != head_dim_v
  - Modify is_fused_attn_kernel_available() to accept different head_dims for qk and v
  - Modify FusedAttnHelper to accept different head_dims for qk and v, and update the dim-check asserts in parse_qkv_aval()
  - Modify FusedAttnFwdPrimitive and FusedAttnBwdPrimitive to accept different head_dims for qk and v
  - Modify the Fused Attn related cpp and csrc extension API calls to accept different head_dims for qk and v
  - Modify DotProductAttention call() to extract head dims separately for qk and v
  - Modify the Fused Attn tests to accommodate the FusedAttn API changes
  - Add a test case for head_dim_qk != head_dim_v (failing)
  - Modify the baseline JAX attention to reshape the output based on the v dims rather than the q dims

  Signed-off-by: Kshitij Janardan Lakhani <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  For more information, see https://pre-commit.ci

* Fix context dims in general DPA in test_fused_attn

  Signed-off-by: Kshitij Janardan Lakhani <[email protected]>

* Fix the dim of the output tensor by using the v head dim rather than the q head dim
  Add test cases for JAX fused attn where head_dim_qk != head_dim_v across a combination of data types and attention types

  Signed-off-by: Kshitij Janardan Lakhani <[email protected]>

* Modify the fused attn JAX unit test case for head_dim_qk != head_dim_v

  Signed-off-by: Kshitij Janardan Lakhani <[email protected]>

* Use the new FusedAttnRunner function signature with separate hidden dims for qk and v in the Fused Attn distributed tests
  Code clean-up

  Signed-off-by: Kshitij Janardan Lakhani <[email protected]>

* Fix usage of the is_fused_attn signature in the distributed tests

  Signed-off-by: Kshitij Janardan Lakhani <[email protected]>

* Remove an unnecessary assert

  Signed-off-by: Kshitij Janardan Lakhani <[email protected]>

---------

Signed-off-by: Kshitij Janardan Lakhani <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent f39fd7a commit c7494e0
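
The core of the change described above is that with MLA-style attention the query/key head dim and the value head dim may differ, and the attention output is shaped by the value head dim. A minimal plain-JAX sketch of that shape behaviour (all sizes below are hypothetical illustrations, not values from the PR):

```python
import jax
import jax.numpy as jnp

# Hypothetical MLA-style sizes: Q/K share head_dim_qk, V uses a different head_dim_v.
b, s_q, s_kv, h, d_qk, d_v = 2, 8, 8, 4, 64, 32
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (b, s_q, h, d_qk))
k = jax.random.normal(kk, (b, s_kv, h, d_qk))
v = jax.random.normal(kv, (b, s_kv, h, d_v))

# Attention scores reduce over head_dim_qk ...
scores = jnp.einsum("bqhd,bkhd->bhqk", q, k) / jnp.sqrt(d_qk)
probs = jax.nn.softmax(scores, axis=-1)
# ... while the context inherits its trailing dim from V.
context = jnp.einsum("bhqk,bkhd->bqhd", probs, v)
assert context.shape == (b, s_q, h, d_v)  # shaped by head_dim_v, not head_dim_qk
```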

File tree

7 files changed: +220 -110 lines changed

tests/jax/test_distributed_fused_attn.py

Lines changed: 6 additions & 0 deletions

@@ -80,6 +80,7 @@ def impl_test_self_attn(
             seqlen,
             seqlen,
             hidden,
+            hidden,
             None,  # no window
         ):
             pytest.skip("No FusedAttn backend found")
@@ -99,6 +100,7 @@ def impl_test_self_attn(
         num_head,
         num_head,
         hidden,
+        hidden,
         attn_bias_type,
         attn_mask_type,
         dropout_prob,
@@ -227,6 +229,7 @@ def test_cross_attn(
             seqlen,
             seqlen,
             hidden,
+            hidden,
             None,  # no window
         ):
             pytest.skip("No FusedAttn backend found")
@@ -239,6 +242,7 @@ def test_cross_attn(
         num_head,
         num_head,
         hidden,
+        hidden,
         attn_bias_type,
         attn_mask_type,
         dropout_prob,
@@ -329,6 +333,7 @@ def impl_test_context_parallel_attn(
         num_head,
         num_kv_heads,
         hidden,
+        hidden,
         attn_bias_type,
         attn_mask_type,
         dropout_prob,
@@ -360,6 +365,7 @@ def check_has_backend_for_mask(mask_type):
             seqlen,
             seqlen,
             hidden,
+            hidden,
             None,
         )  # no SWA for CP
tests/jax/test_fused_attn.py

Lines changed: 55 additions & 20 deletions

@@ -106,7 +106,8 @@ def general_dot_product_attention(
     softmax_out = softmax_out * multiplier

     context = jnp.einsum("...hgqk,...khd->...qhgd", softmax_out, value)
-    context = jnp.reshape(context, query.shape)
+    context_shape = query.shape[:-1] + (value.shape[-1],)
+    context = jnp.reshape(context, context_shape)
     return context

@@ -294,7 +295,8 @@ class FusedAttnRunner:
     max_seqlen_kv: int
     num_heads_q: int
     num_heads_kv: int
-    head_dim: int
+    head_dim_qk: int
+    head_dim_v: int
     attn_bias_type: AttnBiasType
     attn_mask_type: AttnMaskType
     dropout_prob: float
@@ -346,6 +348,14 @@ def _check_configs(self):
                 "seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN"
             )

+        # Test the MLA case where head dims for qk differ from head dims for v, only if the tensors
+        # are provided in BSHD_BSHD_BSHD or THD_THD_THD formats
+        if self.head_dim_qk != self.head_dim_v and not self.qkv_layout.is_separate():
+            pytest.skip(
+                "For head_dim_qk != head_dim_v, it is necessary that the QKV layout "
+                "is either BSHD_BSHD_BSHD or THD_THD_THD"
+            )
+
         self.backend = FusedAttnHelper(
             self.is_training,
             self.dtype,
@@ -358,7 +368,8 @@ def _check_configs(self):
             self.num_heads_kv,
             self.max_seqlen_q,
             self.max_seqlen_kv,
-            self.head_dim,
+            self.head_dim_qk,
+            self.head_dim_v,
             (-1, -1) if self.window_size is None else self.window_size,
         ).get_fused_attn_backend()
         if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend:
@@ -391,13 +402,9 @@ def _setup_inputs(self):
         key = jax.random.PRNGKey(0)
         q_key, k_key, v_key, bias_key, dropout_key = jax.random.split(key, 5)

-        q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim)
-        k_shape = v_shape = (
-            self.batch_size,
-            self.max_seqlen_kv,
-            self.num_heads_kv,
-            self.head_dim,
-        )
+        q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim_qk)
+        k_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim_qk)
+        v_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim_v)

         if self.attn_bias_type == AttnBiasType.NO_BIAS:
             bias_shape = None
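For concreteness, here is what these shapes evaluate to for one of the new MLA parametrizations added further below (the 4-128-128-16-16-64-32-FP16-SELF case):

```python
# Values taken from the 4-128-128-16-16-64-32-FP16-SELF pytest.param below.
batch_size, max_seqlen_q, max_seqlen_kv = 4, 128, 128
num_heads_q = num_heads_kv = 16
head_dim_qk, head_dim_v = 64, 32

q_shape = (batch_size, max_seqlen_q, num_heads_q, head_dim_qk)    # (4, 128, 16, 64)
k_shape = (batch_size, max_seqlen_kv, num_heads_kv, head_dim_qk)  # (4, 128, 16, 64)
v_shape = (batch_size, max_seqlen_kv, num_heads_kv, head_dim_v)   # (4, 128, 16, 32)
```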
@@ -616,7 +623,7 @@ def generate_random_segment_ids(
             raise ValueError(f"Unknown {self.seq_desc_format=}")

         self.dropout_rng = dropout_key if self.dropout_prob > 0 else None
-        self.scaling_factor = 1.0 / sqrt(self.head_dim)
+        self.scaling_factor = 1.0 / sqrt(self.head_dim_qk)

         # Setup distributed sharding specs
         # Setup shardings for distributed tests
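The scaling switches to head_dim_qk because the 1/sqrt(d) factor normalizes the Q·K^T dot product, whose reduction length is the QK head dim; head_dim_v never enters the softmax. A one-line sketch with hypothetical dims:

```python
from math import sqrt

head_dim_qk, head_dim_v = 64, 32  # hypothetical MLA dims
scaling_factor = 1.0 / sqrt(head_dim_qk)  # == 0.125; head_dim_v is irrelevant here
```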
@@ -935,21 +942,45 @@ def check_dqkv(primitive, reference, pad, idx):
     ],
 )
 @pytest.mark.parametrize(
-    "b, s_q, s_kv, h_q, h_kv, d, dtype",
+    "b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype",
     [
-        pytest.param(2, 2048, 2048, 12, 12, 64, jnp.bfloat16, id="2-2048-2048-12-12-64-BF16-SELF"),
+        pytest.param(
+            2, 2048, 2048, 12, 12, 64, 64, jnp.bfloat16, id="2-2048-2048-12-12-64-64-BF16-SELF"
+        ),
+        pytest.param(
+            2,
+            2048,
+            1024,
+            12,
+            12,
+            64,
+            64,
+            jnp.bfloat16,
+            id="2-2048-1024-12-12-64-64-BF16-CROSS",
+        ),
+        pytest.param(
+            2, 2048, 2048, 12, 6, 64, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-64-BF16-GQA"
+        ),
+        pytest.param(
+            4, 128, 128, 16, 16, 64, 64, jnp.float16, id="4-128-128-16-16-64-64-FP16-SELF"
+        ),
+        pytest.param(
+            4, 128, 128, 16, 16, 64, 32, jnp.float16, id="4-128-128-16-16-64-32-FP16-SELF"
+        ),
         pytest.param(
             2,
             2048,
             1024,
             12,
             12,
             64,
+            32,
             jnp.bfloat16,
-            id="2-2048-1024-12-12-64-BF16-CROSS",
+            id="2-2048-1024-12-12-64-32-BF16-CROSS",
+        ),
+        pytest.param(
+            2, 2048, 2048, 12, 6, 128, 64, jnp.float16, id="2-2048-2048-12-6-128-64-FP16-GQA"
         ),
-        pytest.param(2, 2048, 2048, 12, 6, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-BF16-GQA"),
-        pytest.param(4, 128, 128, 16, 16, 64, jnp.float16, id="4-128-128-16-16-64-FP16-SELF"),
     ],
 )
 @pytest.mark.parametrize(
@@ -1003,7 +1034,8 @@ def _test_forward(
     s_kv,
     h_q,
     h_kv,
-    d,
+    d_qk,
+    d_v,
     attn_bias_type,
     attn_mask_type,
     dropout_prob,
@@ -1028,7 +1060,8 @@ def _test_forward(
         s_kv,
         h_q,
         h_kv,
-        d,
+        d_qk,
+        d_v,
         attn_bias_type,
         attn_mask_type,
         dropout_prob,
@@ -1055,7 +1088,8 @@ def test_backward(
     s_kv,
     h_q,
     h_kv,
-    d,
+    d_qk,
+    d_v,
     attn_bias_type,
     attn_mask_type,
     dropout_prob,
@@ -1077,7 +1111,8 @@ def test_backward(
         s_kv,
         h_q,
         h_kv,
-        d,
+        d_qk,
+        d_v,
         attn_bias_type,
         attn_mask_type,
         dropout_prob,

transformer_engine/jax/attention.py

Lines changed: 6 additions & 4 deletions

@@ -188,7 +188,7 @@ class ReorderStrategy(Enum):

     - DualChunkSwap: This strategy splits each query into two chunks and do the mirror swap between
       GPUs. This is currently used for non-THD load balance. It requires the max_seqlens be the
-      mulitple of 2 * cp_size.
+      multiple of 2 * cp_size.
     Examples:
     - Before reorder: GPU0: [0, 1, 2, 3]; GPU1: [4, 5, 6, 7]; GPU2: [8, 9, 10, 11]; GPU3: [12, 13, 14, 15];
     - After reorder: GPU0: [0, 1, 14, 15]; GPU1: [4, 5, 10, 11]; GPU2: [8, 9, 6, 7]; GPU3: [12, 13, 2, 3]
@@ -288,7 +288,8 @@ def is_fused_attn_kernel_available(
     kv_num_heads,
     q_max_seqlen,
     kv_max_seqlen,
-    head_dim,
+    head_dim_qk,
+    head_dim_v,
     window_size: Optional[Tuple[int, int]] = None,
 ):
     """
@@ -308,7 +309,8 @@ def make_helper(attn_mask_type):
             kv_num_heads,
             q_max_seqlen,
             kv_max_seqlen,
-            head_dim,
+            head_dim_qk,
+            head_dim_v,
             (-1, -1) if window_size is None else window_size,
         )

@@ -491,7 +493,7 @@ def _segment_ids_to_seqlens(segment_ids_q, segment_ids_kv, attn_mask_type):

 @jax.tree_util.register_pytree_node_class
 class SequenceDescriptor:
-    """A class to descibe the sequences with flexible initialization.
+    """A class to describe the sequences with flexible initialization.
     - SequenceDescriptor.from_seqlens
       For non-THD (non-packed) cases, where each batch has only 1 sequence.
     - SequenceDescriptor.from_seqlens_and_offsets