IPEX.LLM performance on Intel SPR with large batch size. #581
Comments
@septicmk Thanks for such a comprehensive perf analysis. I guess the SDPA kernels we use for the decoding stage are designed for small batches, to deliver good latency for responsiveness. Inference with small batches is the typical case for CPUs under latency requirements. May I know your particular use cases for larger batch sizes? Continuous batching?
To clarify, the decoding stage is memory-bound with small batches. When the batch size becomes larger, the linear layers become compute-bound while the SDPA part is still memory-bound. You can get this conclusion from the good reference you shared here too. Also, "large memory" is about memory capacity, not memory bandwidth, for which the GPU still has an advantage. But the CPU does bring fairly good LLM inference speed, especially for small- to medium-sized models. To answer your questions:
I'd like to understand your setup better. I noticed you were using deepspeed on a two-socket machine. Is it with tensor parallel? What OMP runtime (GOMP or IOMP) and memory allocator (glibc, jemalloc or tcmalloc) did you use?
For SDPA, yes. As I pointed out, we are optimizing for small-batch cases. We are not fusing the bmm+softmax+bmm in the kernel. When the batch size becomes larger, the activations become larger and can benefit more from fusion. From your roofline analysis, the SDPA kernels are far from the memory bandwidth peak. This is what we should improve further. But I don't think we should use TFLOPS as a measurement. It is memory-bound; memory bandwidth is more appropriate. cc @liangan1 who developed the SDPA kernel.
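A back-of-envelope way to see why SDPA stays memory-bound while the linear layers do not (a sketch with assumed 70B-class shapes, not figures from this thread): each sequence in the batch streams its own KV cache, so SDPA's FLOP/byte ratio is roughly constant in the batch size, whereas a linear layer reuses its weights across the whole batch:

```cpp
// Arithmetic-intensity sketch for one decode step, bf16 everywhere.
// All shapes below (heads, head_dim, context length) are assumptions.
#include <cstdio>

int main() {
    const double bytes_per_elem = 2.0;                        // bf16
    const double n_heads = 64, head_dim = 128, ctx_len = 2048;
    const double hidden = n_heads * head_dim;

    const double batches[] = {1, 8, 64};
    for (double batch : batches) {
        // SDPA: q@K^T and attn@V are ~2*L*d FLOPs each per head per sequence;
        // K and V (L*d elements each) are streamed from that sequence's KV cache.
        double sdpa_flops = batch * n_heads * 4.0 * ctx_len * head_dim;
        double sdpa_bytes = batch * n_heads * 2.0 * ctx_len * head_dim * bytes_per_elem;

        // A square linear layer: the weight matrix is shared across the batch.
        double lin_flops = batch * 2.0 * hidden * hidden;
        double lin_bytes = hidden * hidden * bytes_per_elem;   // weight traffic only

        printf("batch=%2.0f  SDPA: %.2f FLOP/byte   linear: %5.1f FLOP/byte\n",
               batch, sdpa_flops / sdpa_bytes, lin_flops / lin_bytes);
    }
    return 0;
}
```

Under these assumptions the SDPA intensity stays at roughly 1 FLOP/byte at any batch size, while the linear layers' intensity grows linearly with the batch.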
@jgong5 Thank you for your reply. I believe the CPU has the potential to replace the GPU as a cost-efficient solution for LLM inference. Although the GPU is fast, it has its own problems, such as its memory capacity not supporting a batch size large enough to saturate its computational power. Currently, I have a 70B-scale LLM inference service deployed on a GPU cluster. User inputs are about 2048 tokens, and GPU memory is already close to 100% at a batch size of about 40. This batch size is far from reaching the peak computational power of the A100/A800. When I say "large memory", I'm referring to memory capacity. I think this is an opportunity for CPU inference.
Motivation for large batch size: using a small batch size on the CPU can indeed result in lower latency, which makes sense. However, from a cost-efficiency perspective, as long as the SLO is met, we can trade latency for throughput.
My concern about large batch sizes comes from the roofline model's indication that IPEX's batch size could be larger and its throughput higher. In the current tests, IPEX runs well at small batch sizes (<=8), but it could run better at larger ones. Perhaps the following comparison of decode performance behavior with the GPU will make this issue more visual:
Currently it's tp=2. I have also tested tp=1, and it likewise stops improving around batch_size=8. I used the MLPerf settings for OMP and the allocator.
SPR provides such high FLOPS: according to this post, the 8480+ CPU can provide ~150 TFLOPS of bf16 peak performance, while the A100 provides ~300 TFLOPS at half precision. To make good use of the hardware (and the SIMD instruction set), I believe the most intuitive method is to increase the batch size (if the SLO allows). The GPU is limited by memory capacity, so people are trying to quantize the KV cache/weights/activations or use GQA to reduce memory usage. The CPU has no such memory capacity limit, so it should be able to do better. I am very much looking forward to follow-up optimizations to SDPA (like the kernel fusion you mentioned), which I believe can unleash the true power of SPR.
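For a rough sense of where the compute/memory crossover sits for the linear layers, here is a machine-balance sketch with assumed hardware numbers (assumptions, not measurements of the machine in this thread):

```cpp
// Roofline "machine balance" sketch: at roughly what batch size would a
// weight-streaming bf16 GEMM stop being memory-bound?
// The peak and bandwidth figures below are assumptions.
#include <cstdio>

int main() {
    const double peak_flops = 150e12;  // assumed bf16/AMX peak per SPR socket
    const double mem_bw     = 307e9;   // assumed 8-channel DDR5-4800 per socket
    const double balance    = peak_flops / mem_bw;   // FLOP per byte

    // For y = W*x over a batch of B tokens with bf16 weights:
    // FLOPs ~= 2*B*M*N, weight traffic ~= 2*M*N bytes  =>  intensity ~= B FLOP/byte.
    printf("machine balance ~= %.0f FLOP/byte\n", balance);
    printf("=> a weight-streaming GEMM stays under the memory roof until batch ~= %.0f\n",
           balance);
    return 0;
}
```

Under these assumed numbers, the linear layers alone would need a batch size in the hundreds before hitting the compute roof, which is why saturation at batch_size≈8 looks like a kernel-level bottleneck rather than a hardware limit.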
Thanks for sharing. Yes, the large memory capacity is indeed an advantage of the CPU. May I know the minimum latency requirements for online inference on your side?
That sounds good to me.
Thanks. We should definitely fill the gap from the roofline. One thing to clarify (again) is that the bottleneck you pointed out is in SDPA, which is memory-bound rather than compute-bound. The bf16 TOPS don't come into play here. Instead, we should strive for good memory bandwidth utilization.
Describe the issue
LLM inference has always been costly, so recently I have been studying whether it is possible to use a CPU for inference (from a hardware point of view, the memory-bound decode stage seems well suited to the CPU and its large memory). I tested the performance of IPEX.LLM on an Intel SPR machine and found that its throughput saturates quickly and cannot keep growing as the batch size increases (it is close to saturation at about batch_size=8); beyond that, throughput even declines severely as the batch size grows.
I tried to use the roofline model to analyze the upper limit on batch size and found that IPEX.LLM has a large gap to the theoretical peak. On the GPU (based on vLLM), by contrast, the saturating batch_size is very large and conforms to the roofline model, and even once throughput saturates, it does not decline.
Varying batch size
Experiment setting:
Below is the raw data at input_length=2048:
Because the performance of the decode stage surprised me, I then used the roofline model to study its performance behavior.
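For reference, the bound being compared against is the standard roofline formula, where $I$ is the arithmetic intensity of a kernel:

$$
P_{\text{attainable}} = \min\bigl(P_{\text{peak}},\ \text{BW}_{\text{mem}} \times I\bigr),
\qquad I = \frac{\text{FLOPs}}{\text{bytes moved}}
$$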
Roofline (decode stage)
Experiment setting:
Observation:
SDPA kernel
SDPA implementation: the current kernel computes the dot product with fp32 fma (the 16-bit inputs are cast to fp32 first). Perhaps this is the reason for the early saturation of computation power?
DP Kernel benchmark:
Taking Reduce_head as an example, I've simply compared the performance of the dot product:
Version 0: cast fp16 to fp32, then do fp32 fma (current IPEX behavior)
Version 1: pure fp32 fma
Version 2: pure fp16 fma
Version 3: pure bf16 fma with _mm512_dpbf16_ps
The performance of all these kernels is far below 30 TFLOPS (the bf16 peak), so the decode stage does not take full advantage of the compute enhancements brought by AVX512_BF16 or AMX, causing decode to approach compute saturation at a very small batch size.
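For concreteness, here is a minimal, hypothetical sketch of versions 0 and 3 (not the actual IPEX kernel or the benchmark code used above; the function names and the raw uint16_t bf16 storage are assumptions). Compile with something like -O2 -mavx512f -mavx512bf16:

```cpp
#include <immintrin.h>
#include <cstdint>

// Version 0: widen bf16 -> fp32 (place the 16 bits in the high half of an fp32),
// then accumulate with fp32 fma.
float dot_bf16_via_fp32(const uint16_t* a, const uint16_t* b, int n) {
    __m512 acc = _mm512_setzero_ps();
    for (int i = 0; i < n; i += 16) {
        __m512 av = _mm512_castsi512_ps(_mm512_slli_epi32(
            _mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i*)(a + i))), 16));
        __m512 bv = _mm512_castsi512_ps(_mm512_slli_epi32(
            _mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i*)(b + i))), 16));
        acc = _mm512_fmadd_ps(av, bv, acc);
    }
    return _mm512_reduce_add_ps(acc);
}

// Version 3: feed packed bf16 pairs directly to the AVX512_BF16 dot-product
// instruction, which accumulates into fp32.
float dot_bf16_dpbf16(const uint16_t* a, const uint16_t* b, int n) {
    __m512 acc = _mm512_setzero_ps();
    for (int i = 0; i < n; i += 32) {
        // Reinterpret 32 packed bf16 values as a bf16 vector (bit-cast).
        __m512bh av = (__m512bh)_mm512_loadu_si512((const void*)(a + i));
        __m512bh bv = (__m512bh)_mm512_loadu_si512((const void*)(b + i));
        acc = _mm512_dpbf16_ps(acc, av, bv);
    }
    return _mm512_reduce_add_ps(acc);
}
```

This sketch assumes n is a multiple of 32; in both variants the dot product has to stream a and b (and, in SDPA, the KV cache) from memory, which is consistent with the memory-bound picture discussed above.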
My questions