implement group gemm for contiguous case #136
XiaobingSuper wants to merge 5 commits into HazyResearch:main
Conversation
```cpp
if (args.common.is_tma_multicast_valid) {
    if (cluster_ctarank() == 0) {
        tma::cluster::expect(args.inputs_cluster_arrived, 0, args.input.b);
```
@benjaminfspector @DanFu09 @simran-arora Currently, the non-multicast implementation reaches about 80% of DeepGEMM's performance, but there is a performance regression when TMA multicast is enabled. I tested DeepGEMM's TMA multicast path, which achieves a 10%-20% performance improvement, so there may be an issue in my code. Could you help review it? Thanks!
@DanFu09 @benjaminfspector @simran-arora @StuartSul Any suggestions? Thanks!
See this: https://github.com/HazyResearch/ThunderKittens/pull/98/files You probably want to use larger wgmma ops.
Yes, I tried that; it gets better performance with a larger N block size (DeepGEMM uses a block size of 192). But I also want to use the TMA multicast feature to reduce loads from global memory to shared memory: in my tests of the DeepGEMM code, its TMA multicast path gets a 10%-20% performance improvement. Thanks!
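The memory-traffic argument above can be made concrete. For one block_m x block_n output tile, each k-step loads an FP8 A tile (block_m x block_k bytes) and an FP8 B tile (block_n x block_k bytes); a 2-CTA cluster multicasting the B tile lets each CTA pay only half of that load. A rough sketch (the helper name and the multicast accounting are illustrative, not from the kernel):

```python
def bytes_per_ktile(block_m, block_n, block_k, multicast=1):
    """Global-memory bytes one CTA pulls per k-step for an FP8 GEMM tile.

    With a 2-CTA cluster multicasting the B tile, each CTA effectively
    pays 1/multicast of the B load (an idealized model).
    """
    a = block_m * block_k               # FP8 A tile: 1 byte/element
    b = block_n * block_k / multicast   # FP8 B tile, shared across the cluster
    return a + b

base = bytes_per_ktile(128, 128, 128)                 # square tiles, no multicast
wide = bytes_per_ktile(128, 192, 128)                 # DeepGEMM-style block_n=192
mc   = bytes_per_ktile(128, 128, 128, multicast=2)    # 2-CTA B multicast

# Bytes per output element per k-step: lower is better.
print(base / (128 * 128))   # 2.0
print(wide / (128 * 192))   # ~1.67 -> larger block_n amortizes the A load
print(mc / (128 * 128))     # 1.5  -> multicast halves the B traffic
```

This is why both knobs help: a wider N tile and B-tile multicast each cut the per-element global-memory traffic, which is the point of chasing the multicast path even though the larger block size alone already helps.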
Force-pushed 84040e3 to 0683cc0
Current performance is:
Testing grouped contiguous GEMM for deepgemm (block_m=block_k=128, block_n=192):
> Perf (num_groups= 4, expected_m_per_group=8192, n=4096, k=7168): 1408 us | throughput: 1367 TFLOPS, 441 GB/s
> Perf (num_groups= 4, expected_m_per_group=8192, n=7168, k=2048): 749 us | throughput: 1285 TFLOPS, 796 GB/s
> Perf (num_groups= 8, expected_m_per_group=4096, n=4096, k=7168): 1411 us | throughput: 1364 TFLOPS, 523 GB/s
> Perf (num_groups= 8, expected_m_per_group=4096, n=7168, k=2048): 752 us | throughput: 1279 TFLOPS, 870 GB/s
> Perf (num_groups=32, expected_m_per_group= 256, n=4096, k=7168): 465 us | throughput: 1036 TFLOPS, 2294 GB/s
> Perf (num_groups=32, expected_m_per_group= 256, n=7168, k=2048): 244 us | throughput: 985 TFLOPS, 2475 GB/s
Testing grouped contiguous GEMM for deepgemm (block_m=block_n=block_k=128):
> Perf (num_groups= 4, expected_m_per_group=8192, n=4096, k=7168): 1505 us | throughput: 1279 TFLOPS, 412 GB/s
> Perf (num_groups= 4, expected_m_per_group=8192, n=7168, k=2048): 796 us | throughput: 1209 TFLOPS, 748 GB/s
> Perf (num_groups= 8, expected_m_per_group=4096, n=4096, k=7168): 1504 us | throughput: 1279 TFLOPS, 491 GB/s
> Perf (num_groups= 8, expected_m_per_group=4096, n=7168, k=2048): 798 us | throughput: 1206 TFLOPS, 820 GB/s
> Perf (num_groups=32, expected_m_per_group= 256, n=4096, k=7168): 482 us | throughput: 998 TFLOPS, 2211 GB/s
> Perf (num_groups=32, expected_m_per_group= 256, n=7168, k=2048): 256 us | throughput: 939 TFLOPS, 2358 GB/s
Testing grouped contiguous GEMM for tk (block_m=block_n=block_k=128):
> Perf (num_groups= 4, expected_m_per_group=8192, n=4096, k=7168): 2150 us | throughput: 895 TFLOPS, 289 GB/s
> Perf (num_groups= 4, expected_m_per_group=8192, n=7168, k=2048): 1137 us | throughput: 846 TFLOPS, 524 GB/s
> Perf (num_groups= 8, expected_m_per_group=4096, n=4096, k=7168): 2148 us | throughput: 896 TFLOPS, 344 GB/s
> Perf (num_groups= 8, expected_m_per_group=4096, n=7168, k=2048): 1136 us | throughput: 847 TFLOPS, 576 GB/s
> Perf (num_groups=32, expected_m_per_group= 256, n=4096, k=7168): 595 us | throughput: 808 TFLOPS, 1790 GB/s
> Perf (num_groups=32, expected_m_per_group= 256, n=7168, k=2048): 315 us | throughput: 763 TFLOPS, 1916 GB/s

Note that the current TK implementation using multicast has about a 15% performance regression compared to the non-multicast path.
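As a sanity check on the tables above, TFLOPS and bandwidth can be recomputed from the problem shape alone. Assuming 1-byte FP8 A and B tensors, a 2-byte BF16 output, and ignoring the small scaling-factor tensors (an assumption about the benchmark's accounting, not something stated in this PR), the first DeepGEMM row reproduces closely:

```python
def gemm_perf(num_groups, m_per_group, n, k, t_us):
    """Recompute TFLOPS and GB/s for a grouped contiguous FP8 GEMM.

    A is (num_groups*m_per_group) x k in FP8, B is num_groups x n x k in FP8,
    and the output is BF16; scaling factors are ignored (small contribution).
    """
    m = num_groups * m_per_group
    flops = 2 * m * n * k                                  # one FMA = 2 FLOPs
    bytes_moved = m * k + num_groups * n * k + 2 * m * n   # fp8 A + fp8 B + bf16 D
    t = t_us * 1e-6
    return flops / t / 1e12, bytes_moved / t / 1e9

tflops, gbps = gemm_perf(4, 8192, 4096, 7168, 1408)
print(f"{tflops:.0f} TFLOPS, {gbps:.0f} GB/s")   # ~1367 TFLOPS, ~441 GB/s
```

The agreement with the reported 1367 TFLOPS / 441 GB/s suggests this is the data model the benchmark uses; the bandwidth-bound small-group rows (num_groups=32) are where the multicast savings should matter most.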
This pull request implements FP8 group GEMM for the contiguous case. The masked case will be added in a next step.
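For context, the "contiguous" case here follows the DeepGEMM-style grouped layout: all groups are concatenated along M, each group's row count is padded up to a multiple of block_m so every M-block belongs to exactly one group, and an index tensor maps rows to their group, with padding rows marked invalid. A minimal sketch of that bookkeeping (the helper name and -1 padding convention are illustrative, not this PR's API):

```python
def build_m_indices(group_ms, block_m=128, pad_value=-1):
    """Map each row of the concatenated-M layout to its group id.

    Each group's rows are padded up to a multiple of block_m so every
    block_m-sized M-block reads B from exactly one group; padding rows
    get pad_value and their results are discarded.
    """
    indices = []
    for g, m in enumerate(group_ms):
        padded = (m + block_m - 1) // block_m * block_m
        indices += [g] * m + [pad_value] * (padded - m)
    return indices

idx = build_m_indices([200, 300], block_m=128)
print(len(idx))                                 # 640: 256 + 384 rows after padding
print(idx[199], idx[200], idx[255], idx[256])   # 0 -1 -1 1
```

With this layout the kernel can treat the whole problem as one tall GEMM and look up which group's B (and scales) each M-block should use, which is what makes the contiguous case simpler than the masked one deferred to the next step.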