
IPEX.LLM performance on Intel SPR with large batch size. #581

Open · septicmk opened this issue Apr 2, 2024 · 3 comments

septicmk commented Apr 2, 2024

Describe the issue

LLM inference has always been costly, so I have recently been studying whether a CPU can be used for inference (from a hardware point of view, the memory-bound decode stage seems well suited to the CPU with its large memory). I tested the performance of IPEX.LLM on an Intel SPR machine and found that its throughput saturates quickly and stops growing with batch size (it is close to saturation at about batch_size=8); worse, throughput declines sharply as the batch size increases further.

I used the roofline model to estimate the batch size at which saturation should occur and found that IPEX.LLM falls well short of the theoretical peak. On a GPU (with vLLM), by contrast, the saturating batch size is much larger and matches the roofline model, and even once throughput saturates it does not decline.

Varying batch size

Experiment setting:

  • IPEX==v2.2.0+cpu, pytorch==2.2.0+cpu.
  • On a server with Intel Xeon Platinum 6462C (Sapphire Rapids), 64 physical cores across 2 sockets.
  • On this server, oneDNN matmul_perf reports 30 TFLOPS on the BF16 GEMM benchmark and 9 TFLOPS on the FP32 GEMM benchmark.
  • Batch size varied over [1, 2, 4, 8, 16, 32, 64, 128], with input lengths of [128, 512, 1024, 2048].
  • Run benchmark with:
deepspeed --bind_cores_to_rank \
  run.py \
    --num-warmup=2 \
    --num-iter=10 \
    --input-tokens=${INPUT_LENGTH} \
    --max-new-tokens=200 \
    --batch-size=${BATCH_SIZE} \
    --token-latency \
    --ipex \
    --autotp \
    --benchmark \
    --dtype bfloat16 \
    -m ./quant/llama2-72b-local_shard/llama_local_shard

(figure: batching2 — throughput vs. batch size)
Below is the raw data at input_length=2048

batch size   1        2        4        8        16       32       64        128
TTFT (s)     25.047   53.501   108.545  216.014  435.736  880.231  1749.547  3412.132
TPOT (s)     0.660    0.714    0.875    1.156    2.694    6.641    19.374    42.111
  • Observations:
    • TTFT is prefill latency; TPOT is decode latency.
    • Prefill: TTFT scales essentially linearly; compute is already saturated, so doubling the batch size doubles the TTFT.
    • Decode:
      • For batch size < 8, TPOT hardly increases.
      • However, when the batch size grows from 8 to 64 (8x), TPOT grows from 1.156 to 19.374 (nearly 17x). That means a single batch of 64 is worse than splitting it into eight runs of batch 8 (see the quick throughput calculation below). Does this mean throughput is already saturated at batch size 8?
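For reference, decode throughput can be read directly off the table as batch_size / TPOT. Below is a minimal sketch of that calculation; the values are copied from the input_length=2048 table above and nothing else is assumed.

// Sketch: derive decode throughput (tokens/s) from the measured TPOT values above.
#include <cstdio>

int main() {
  const int    batch[]  = {1, 2, 4, 8, 16, 32, 64, 128};
  const double tpot_s[] = {0.660, 0.714, 0.875, 1.156, 2.694, 6.641, 19.374, 42.111};
  for (int i = 0; i < 8; ++i) {
    // Each decode step emits one token per sequence in the batch.
    double tokens_per_s = batch[i] / tpot_s[i];
    std::printf("batch=%3d  TPOT=%7.3f s  decode throughput=%5.2f tok/s\n",
                batch[i], tpot_s[i], tokens_per_s);
  }
  return 0;
}

With these numbers, decode throughput peaks around batch size 8 (~6.9 tok/s) and falls back to ~3 tok/s at batch 64-128, which is what motivates the question above.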

Because the decode-stage performance surprised me, I then used the roofline model to study its behavior.

Roofline (decode stage)

Experiment setting:

  • Same as above, but with llama2-7b instead of llama2-72b (single instance).
  • FLOPs and IO are computed based on this blog (a rough sketch of the estimate is given after this list).
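As a rough illustration of the kind of estimate involved (a sketch under common simplifying assumptions, not necessarily the blog's exact formulas; the layer sizes below are llama2-7b-like and the bandwidth number is only a placeholder): for a bf16 linear layer the weight traffic dominates, so arithmetic intensity grows roughly linearly with batch size, while for decode-stage SDPA the K/V cache is unique per sequence, so intensity stays near 1 FLOP/byte no matter how large the batch is.

// Rough per-decode-step roofline estimate (sketch; helper names and the bandwidth
// value are my assumptions). bf16 weights and KV cache (2 bytes/element), activations
// ignored, MHA (no GQA), a single layer.
#include <algorithm>
#include <cstdio>
#include <initializer_list>

struct Estimate { double flops, bytes; };

// Linear layer [d_in x d_out] applied to a batch of single-token activations.
Estimate linear_step(double batch, double d_in, double d_out) {
  return {2.0 * batch * d_in * d_out,          // GEMM FLOPs
          2.0 * d_in * d_out};                 // weight bytes, reused across the batch
}

// Decode SDPA: one query token per sequence against a KV cache of length seq.
Estimate sdpa_step(double batch, double heads, double head_dim, double seq) {
  return {4.0 * batch * heads * seq * head_dim,         // QK^T + PV FLOPs
          2.0 * 2.0 * batch * heads * seq * head_dim};  // K and V bytes, no reuse across batch
}

int main() {
  const double peak_flops = 30e12;  // measured bf16 GEMM peak from oneDNN matmul_perf
  const double peak_bw    = 300e9;  // placeholder DRAM bandwidth (bytes/s), machine-dependent
  for (double b : {1.0, 8.0, 64.0}) {
    Estimate mlp  = linear_step(b, 4096, 4096);   // llama2-7b-like hidden size
    Estimate sdpa = sdpa_step(b, 32, 128, 2048);
    auto roof = [&](Estimate e) {
      double ai = e.flops / e.bytes;              // arithmetic intensity (FLOP/byte)
      return std::min(peak_flops, ai * peak_bw);  // attainable FLOPS under the roofline
    };
    std::printf("batch=%3.0f  MLP AI=%5.1f FLOP/B (ceiling %5.1f TFLOPS) | SDPA AI=%3.1f FLOP/B (ceiling %4.2f TFLOPS)\n",
                b, mlp.flops / mlp.bytes, roof(mlp) / 1e12,
                sdpa.flops / sdpa.bytes, roof(sdpa) / 1e12);
  }
  return 0;
}

This is why the MLP part can climb toward the compute roof as the batch grows, while SDPA stays pinned near the memory roof.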

(figure: roofline plot for the decode stage)

Observation:

  • The decode can be divided into the MLP part (including the QKV stage and the FFN at the end of each Transformer layer) and the Attention part (SDPA).
  • The MLP part conforms to the roofline model and reaches about 28 TFLOPS, close to the performance reported by the bf16 GEMM kernel.
  • The SDPA part does not improve with batch size: it fully utilizes neither the bandwidth nor the compute, reaching only about 700 GFLOPS.
  • SDPA is the direct cause of the decline in decode performance.

SDPA kernel

SDPA implementation:

  • The most frequently called kernels are reduce_head and mul_attenion_weights_and_value_of_head; both are dot products.
  • Their current implementation first casts bf16 to fp32 and then performs FMA in fp32. Perhaps this is the reason compute saturates so early?

DP kernel benchmark:

  • Taking reduce_head as an example, I made a simple comparison of several dot-product implementations (a sketch of the missing helper and a timing harness follows the listings):

  • Version 0: cast bf16 to fp32, then do fp32 FMA (current IPEX behavior)

      void cvt_fma_bf16(const uint16_t* a, const uint16_t* b, float* c, size_t n) {
        size_t head_size = n;
        size_t vec_size = 256 / 8 / sizeof(uint16_t);
        auto q_ptr_start = a;
        auto k_ptr_start = b;
        auto sum_vec = _mm512_setzero_ps();
        for (size_t hsi = 0; hsi <= head_size - vec_size; hsi += vec_size) {
          // load 16 bfloat16 query from q_ptr_start and convert to 16 float32 values
          auto q_vec_bf16 = _mm256_loadu_si256((__m256i*)(q_ptr_start + hsi));
          auto q_vec_fp32 = convert_bf16_to_fp32(q_vec_bf16);
          // load 16 bfloat16 key from k_ptr_start and convert to 16 float32 values
          auto k_vec_bf16 = _mm256_loadu_si256((__m256i*)(k_ptr_start + hsi));
          auto k_vec_fp32 = convert_bf16_to_fp32(k_vec_bf16);
          sum_vec = _mm512_fmadd_ps(q_vec_fp32, k_vec_fp32, sum_vec);
        }
        c[0] += _mm512_reduce_add_ps(sum_vec);
      }
      
  • Version 1: pure fp32 fma

      void fma_f32(const float* a, const float* b, float* c, size_t n) {
        size_t vec_size = 512 / 8 / sizeof(float);
        auto sum_vec = _mm512_setzero_ps();
        for (int i = 0; i <= n - vec_size; i += vec_size) {
          __m512 q = _mm512_load_ps(a + i);
          __m512 k = _mm512_load_ps(b + i);
      
          sum_vec = _mm512_fmadd_ps(q, k, sum_vec);
        }
        c[0] += _mm512_reduce_add_ps(sum_vec);
      }
      
  • Version 2: pure fp16 fma

      void fma_fp16(const uint16_t* a, const uint16_t* b, float* c, size_t n) {
        size_t head_size = n;
        size_t vec_size = 512 / 8 / sizeof(uint16_t);
        auto sum_vec = _mm512_setzero_ph();
        for (int i = 0; i <= n - vec_size; i += vec_size) {
          auto q = _mm512_loadu_ph(a + i);
          auto k = _mm512_loadu_ph(b + i);
          sum_vec = _mm512_fmadd_ph(q, k, sum_vec);
        }
        c[0] += _mm512_reduce_add_ph(sum_vec);
      }
      
  • Version 3: bf16 dot product with _mm512_dpbf16_ps

      void fma_bf16(const uint16_t* a, const uint16_t* b, float *c, size_t n) {
        size_t head_size = n;
        size_t vec_size = 512 / 8 / sizeof(uint16_t);
        auto sum_vec = _mm512_setzero_ps();
        for (int i = 0; i <= n - vec_size; i += vec_size) {
          auto q = _mm512_loadu_si512(a + i);
          auto k = _mm512_loadu_si512(b + i);
          sum_vec = _mm512_dpbf16_ps(sum_vec, (__m512bh)q, (__m512bh)k);
        }
        c[0] += _mm512_reduce_add_ps(sum_vec);
      }
      

(figure: dot-product kernel benchmark results)

The performance of all these kernels is far below 30 TFLOPS (the bf16 peak), so the decode stage does not take full advantage of the compute enhancements brought by AVX512_BF16 or AMX, and decode approaches compute saturation at a very small batch size.

My questions

  • Is my experiment setup correct, and is this the expected performance of IPEX?
  • Does IPEX have the opportunity to perform inference with large batches on the CPU (e.g., batch size = 256)? With vLLM on GPU, TPOT does not increase significantly as batch_size grows from 1 to 128.
  • Is there a chance to optimize SDPA? For example, to boost its performance from 700 GFLOPS to ~10 TFLOPS?
jgong5 (Contributor) commented Apr 3, 2024

@septicmk Thanks for such a comprehensive perf analysis. I guess the SDPA kernels we use for the decoding stage are designed for small batches, to deliver good latency for responsiveness. Inference with small batches is the typical CPU case under latency requirements. May I know your particular use cases for larger batch sizes? Continuous batching?

I have recently been studying whether a CPU can be used for inference (from a hardware point of view, the memory-bound decode stage seems well suited to the CPU with its large memory).

To clarify, the decoding stage is memory-bound with small batches. When the batch size becomes larger, the linear layers become compute-bound while the SDPA part is still memory-bound. You can draw this conclusion from the good reference you shared here too. Also, "large memory" is about memory capacity, not memory bandwidth, where the GPU still has the advantage. But the CPU does deliver fairly good LLM inference speed, especially for small-to-medium sized models.

To answer your questions:

Is my experiment setup correct, and is this the expected performance of IPEX?

I'd like to understand your setup better. I noticed you were using deepspeed on a two-socket machine. Is it with tensor parallel? What OMP runtime (GOMP or IOMP) and memory allocator (glibc, jemalloc or tcmalloc) did you use?

Does IPEX have the opportunity to perform inference with large batches on the CPU (e.g., batch size = 256)? With vLLM on GPU, TPOT does not increase significantly as batch_size grows from 1 to 128.
Is there a chance to optimize SDPA? For example, to boost its performance from 700 GFLOPS to ~10 TFLOPS?

For SDPA, yes. As I pointed out, we are optimizing for small-batch cases. We are not fusing the bmm+softmax+bmm in the kernel. When the batch size becomes larger, the activations become larger and can benefit more from fusion. From your roofline analysis, the SDPA kernels are far from the memory bandwidth peak. This is what we should improve further. But I don't think we should use TFLOPS as a measurement; SDPA is memory-bound, so memory bandwidth is more appropriate. cc @liangan1 who developed the SDPA kernel.
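To illustrate what fusing bmm+softmax+bmm means in principle, here is a minimal scalar sketch of a fused decode-attention step for one query token and one head, using an online softmax so the [seq_len] score vector never round-trips through memory between the two matmuls. This is only an illustration of the idea, not our actual kernel; a real implementation would be vectorized and blocked over several K/V rows.

// Fused decode attention (single query token, single head), scalar sketch only.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

void fused_decode_attention(const float* q,   // [head_dim]
                            const float* k,   // [seq_len, head_dim] key cache
                            const float* v,   // [seq_len, head_dim] value cache
                            float* out,       // [head_dim]
                            size_t seq_len, size_t head_dim) {
  const float scale = 1.0f / std::sqrt((float)head_dim);
  float running_max = -INFINITY, running_sum = 0.0f;
  std::vector<float> acc(head_dim, 0.0f);  // un-normalized weighted sum of V rows

  for (size_t t = 0; t < seq_len; ++t) {
    // score_t = scale * <q, k_t>
    float s = 0.0f;
    for (size_t d = 0; d < head_dim; ++d) s += q[d] * k[t * head_dim + d];
    s *= scale;

    // Online softmax update: rescale the accumulator whenever a new max appears,
    // so the full score vector never needs to be materialized.
    float new_max = std::max(running_max, s);
    float correction = std::exp(running_max - new_max);
    float p = std::exp(s - new_max);
    for (size_t d = 0; d < head_dim; ++d)
      acc[d] = acc[d] * correction + p * v[t * head_dim + d];
    running_sum = running_sum * correction + p;
    running_max = new_max;
  }
  for (size_t d = 0; d < head_dim; ++d) out[d] = acc[d] / running_sum;
}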

septicmk (Author) commented Apr 3, 2024

@jgong5 Thank you for your reply. I believe the CPU has the potential to replace the GPU as a cost-efficient solution for LLM inference. Although the GPU is fast, it has its own problems, such as its memory capacity not supporting a sufficient batch size to saturate computational power. Currently, I have a 70B scale LLM inference service deployed on a GPU cluster. User inputs are about 2048 tokens, and the GPU memory is already close to 100% with a batch size of about 40. This batch size is far from reaching the peak computational power of the A100/A800. When saying "Large memory", I'm referring to memory capacity. I think this is an opportunity for CPU inference.

Motivation for Large batch size: Using a small batch size on the CPU can indeed result in lower latency, which makes sense. However, from a cost-efficiency perspective, as long as the SLO is met, we can trade latency for throughput.

  • For online inference, under the constraints of meeting SLO, larger batches mean more serving capacity.
  • For offline inference, batch size should be maximized until both bandwidth and computational power are saturated.

My concern about large batch sizes comes from the roofline model's indication that IPEX's batch size could be larger and its throughput higher. In the current tests, IPEX runs well at small batch sizes (<=8), but it should be able to do better at larger ones. Perhaps the following comparison of decode behavior against the GPU makes the issue more concrete:

(figure: decode roofline, SPR vs. GPU)

  • The figure above shows how decode performance of a 70B model changes with batch size on SPR and on GPU (the two lines correspond to input lengths of 1024 and 2048 respectively).
  • For the GPU (vLLM), as you said, decode gradually becomes compute-bound as the batch size increases; however, memory capacity prevents increasing it further.
  • For the CPU (IPEX), the growth trend looks abnormal once batch size > 8, and throughput even decreases. Increasing the batch size actually becomes less efficient, which seems unreasonable.

I'd like to understand your setup better. I noticed you were using deepspeed on a two-socket machine. Is it with tensor parallel? What OMP runtime (GOMP or IOMP) and memory allocator (glibc, jemalloc or tcmalloc) did you use?

Currently it's tp=2. I have also tested tp=1, and it likewise stops improving around batch_size=8.

I used the MLPerf settings for OMP and the allocator:

export LD_PRELOAD=${LD_PRELOAD}:${CONDA_PREFIX}/lib/libiomp5.so
export LD_PRELOAD=${LD_PRELOAD}:${CONDA_PREFIX}/lib/libtcmalloc.so

we are optimizing for small-batch cases. We are not fusing the bmm+softmax+bmm in the kernel. When the batch size becomes larger, the activations become larger and can benefit more from fusion.

SPR provides very high FLOPS: according to this post, the 8480+ CPU can deliver ~150 TFLOPS of bf16 peak performance, while the A100 provides ~300 TFLOPS at half precision. To make good use of the hardware (and its SIMD instruction set), I believe the most intuitive approach is to increase the batch size (if the SLO allows). GPUs are limited by memory capacity, so people quantize KV cache/weights/activations or use GQA to reduce memory usage; the CPU has no such capacity limit and should be able to do better. I am very much looking forward to follow-up SDPA optimizations (such as the kernel fusion you mentioned), which I believe can unleash the true power of SPR.

jgong5 (Contributor) commented Apr 4, 2024

Currently, I have a 70B scale LLM inference service deployed on a GPU cluster. User inputs are about 2048 tokens, and the GPU memory is already close to 100% with a batch size of about 40. This batch size is far from reaching the peak computational power of the A100/A800. When saying "Large memory", I'm referring to memory capacity. I think this is an opportunity for CPU inference.

Motivation for Large batch size: Using a small batch size on the CPU can indeed result in lower latency, which makes sense. However, from a cost-efficiency perspective, as long as the SLO is met, we can trade latency for throughput.

  • For online inference, under the constraints of meeting SLO, larger batches mean more serving capacity.
  • For offline inference, batch size should be maximized until both bandwidth and computational power are saturated.

Thanks for sharing. Yes, the large memory capacity is indeed an advantage of the CPU. May I know the latency requirements for online inference on your side?

I used the MLPerf settings for OMP and the allocator

That sounds good to me.

SPR provides very high FLOPS: according to this post, the 8480+ CPU can deliver ~150 TFLOPS of bf16 peak performance, while the A100 provides ~300 TFLOPS at half precision. To make good use of the hardware (and its SIMD instruction set), I believe the most intuitive approach is to increase the batch size (if the SLO allows). GPUs are limited by memory capacity, so people quantize KV cache/weights/activations or use GQA to reduce memory usage; the CPU has no such capacity limit and should be able to do better. I am very much looking forward to follow-up SDPA optimizations (such as the kernel fusion you mentioned), which I believe can unleash the true power of SPR.

Thanks. We should definitely fill the gap shown by the roofline. One thing to clarify (again) is that the bottleneck you pointed out is in SDPA, which is memory-bound rather than compute-bound. The bf16 peak doesn't come into play here; instead, we should strive for good memory bandwidth utilization.

jingxu10 added the CPU, Performance, and LLM labels on Apr 10, 2024