🚀 The feature, motivation and pitch
For low-concurrency cases, the kernels that run after the all-reduce and before the first fused_moe kernel take about as long as the first fused_moe kernel itself. At low concurrency these kernels launch with very few CTAs, so each one uses only a small fraction of the GPU, and because there are so many of them we also lose a lot of time to inter-kernel gaps.
For concurrency 4 on H200 with EP8:
triton_red_Fused__to_copy_add_mean_mul_pow_rqsrt_0: 3.42us
nvjet_tst_64x8_64x16_4x1_v_bz_splitK_TNT: 4.64us
splitKReduce_kernel: 1.66us
per_token_group_quant_8bit_kernel: 2.18us
sm90_fp8_gemm_1d2d_impl: 8.77us
elementwise_kernel: 2.18us
elementwise_kernel: 2.02us
per_token_group_quant_8bit_kernel: 1.73us
sm90_fp8_gemm_1d2d_impl: 2.78us
triton_poi_fused__to_copy_add_sigmoid_0: 1.31us
topk_with_k2_kernel: 1.34us
group_idx_and_topk_idx_kernel: 6.50us
triton_poi_fused__to_copy_0: 0.93us
per_token_group_quant_8bit_kernel: 2.08us
moe_align_block_size_kernel: 3.87us
count_and_sort_expert_tokens_kernel: 1.25us
unrolled_elementwise_kernel: 1.60us
index_elementwise_kernel: 2.21us
Total, including inter-kernel gaps: 52.1us.
For comparison, fused_moe: 51.1us.
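For reference, a breakdown like the one above can be reproduced with the PyTorch profiler. This is only a minimal sketch; `decode_step` is a hypothetical stand-in for one forward step at concurrency 4:

```python
import torch
from torch.profiler import profile, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    decode_step()                 # hypothetical: one decode iteration at concurrency 4
    torch.cuda.synchronize()

# aggregated per-kernel CUDA times
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=30))
# the timeline view is what shows the inter-kernel gaps
prof.export_chrome_trace("decode_step.json")
```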
Where possible, we should fuse these kernels for the low-concurrency case to bring this time down.
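One concrete candidate is fusing the residual add + RMSNorm (triton_red_Fused__to_copy_add_mean_mul_pow_rqsrt_0) with the per-token-group FP8 quant (per_token_group_quant_8bit_kernel) that feeds the first FP8 GEMM. Below is a minimal Triton sketch of that fusion, not vLLM's actual implementation: the function names, one-program-per-token launch, and scale convention (scale = amax / fp8_max, dequant by multiplication) are all assumptions.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _fused_add_rmsnorm_group_quant(
    x_ptr,       # [num_tokens, hidden_size] all-reduce output
    res_ptr,     # [num_tokens, hidden_size] residual
    w_ptr,       # [hidden_size] RMSNorm weight
    q_ptr,       # [num_tokens, hidden_size] fp8 output
    scale_ptr,   # [num_tokens, num_groups] per-group scales
    hidden_size, num_groups, eps, fp8_max,
    GROUP_SIZE: tl.constexpr, NUM_GROUPS_POW2: tl.constexpr,
):
    # One program per token; the whole hidden dimension is held in registers.
    row = tl.program_id(0)
    g = tl.arange(0, NUM_GROUPS_POW2)
    e = tl.arange(0, GROUP_SIZE)
    offs = g[:, None] * GROUP_SIZE + e[None, :]          # (groups, group_size)
    mask = offs < hidden_size

    x = tl.load(x_ptr + row * hidden_size + offs, mask=mask, other=0.0).to(tl.float32)
    r = tl.load(res_ptr + row * hidden_size + offs, mask=mask, other=0.0).to(tl.float32)
    x = x + r                                            # residual add
    # write the updated residual back for the next layer (assumed in-place convention)
    tl.store(res_ptr + row * hidden_size + offs, x.to(res_ptr.dtype.element_ty), mask=mask)

    # RMSNorm over the whole hidden dimension
    var = tl.sum(tl.sum(x * x, axis=1), axis=0) / hidden_size
    y = x * (1.0 / tl.sqrt(var + eps))
    w = tl.load(w_ptr + offs, mask=mask, other=0.0).to(tl.float32)
    y = y * w

    # per-token-group quantization: one scale per GROUP_SIZE contiguous elements
    amax = tl.max(tl.abs(y), axis=1)                     # (groups,)
    scale = tl.maximum(amax, 1e-10) / fp8_max
    q = (y / scale[:, None]).to(q_ptr.dtype.element_ty)

    tl.store(q_ptr + row * hidden_size + offs, q, mask=mask)
    tl.store(scale_ptr + row * num_groups + g, scale, mask=g < num_groups)


def fused_add_rmsnorm_group_quant(x, residual, weight, group_size=128, eps=1e-6):
    num_tokens, hidden_size = x.shape
    num_groups = triton.cdiv(hidden_size, group_size)
    q = torch.empty_like(x, dtype=torch.float8_e4m3fn)
    scales = torch.empty((num_tokens, num_groups), dtype=torch.float32, device=x.device)
    _fused_add_rmsnorm_group_quant[(num_tokens,)](
        x, residual, weight, q, scales,
        hidden_size, num_groups, eps,
        torch.finfo(torch.float8_e4m3fn).max,
        GROUP_SIZE=group_size,
        NUM_GROUPS_POW2=triton.next_power_of_2(num_groups),
    )
    return q, scales
```

This would replace two kernel launches (plus their gap) with one; similar pairings (e.g. the sigmoid routing epilogue with top-k selection) could be handled the same way.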
Alternatives
TensorRT-LLM launches the equivalent kernels on two separate CUDA streams to better utilize the GPU, and it also fuses some of them; as a result, the same span runs in ~25us.
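For illustration, here is a minimal sketch of the two-stream idea in PyTorch. It is not TensorRT-LLM's actual implementation; `shared_experts_fn` and `router_fn` are hypothetical stand-ins for the shared-expert GEMM path and the sigmoid/top-k/quant routing path.

```python
import torch

def overlapped_moe_prologue(hidden_states, shared_experts_fn, router_fn):
    main = torch.cuda.current_stream()
    side = torch.cuda.Stream()

    # Launch the shared-expert GEMMs on a side stream so they overlap with the
    # small routing kernels issued on the main stream.
    side.wait_stream(main)                 # side stream must see the all-reduce output
    with torch.cuda.stream(side):
        shared_out = shared_experts_fn(hidden_states)
    hidden_states.record_stream(side)      # keep the buffer alive for the side stream

    topk_weights, topk_ids = router_fn(hidden_states)

    # Re-join: main stream must not consume shared_out before the side stream is done.
    main.wait_stream(side)
    return shared_out, topk_weights, topk_ids
```

The overlap only helps while each path leaves enough of the GPU idle for the other, which is exactly the low-concurrency regime described above.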
Additional context
This is mentioned in #24629 as potential fusion 8.