Description
Describe the issue
LLM inference has always been costly, so I have recently been studying whether a CPU can be used for inference (from a hardware point of view, the memory-bound decode stage seems well suited to the CPU and its large memory). I tested the throughput of IPEX.LLM on an Intel SPR machine and found that it saturates quickly and stops scaling with batch size (it is close to saturation at around batch_size=8); beyond that point, throughput even degrades severely as the batch size grows.
I used the roofline model to estimate the upper limit of the batch size and found that IPEX.LLM has a large gap from the theoretical peak. On a GPU (based on vllm), by contrast, the saturating batch size is much larger and matches the roofline model, and even once throughput saturates it does not decline.
Varying batch size
Experiment setting:
- IPEX==v2.2.0+cpu, pytorch==2.2.0+cpu.
- on a server equipped with an Intel Xeon Platinum 6462C (Sapphire Rapids), 64 physical cores across 2 sockets.
- on this server, the oneDNN matmul_perf benchmark reports 30 TFLOPS for BF16 GEMM and 9 TFLOPS for FP32 GEMM.
- varying batch size in [1, 2, 4, 8, 16, 32, 64, 128] and input length in [128, 512, 1024, 2048].
- Run the benchmark with:

```bash
deepspeed --bind_cores_to_rank \
  run.py \
  --num-warmup=2 \
  --num-iter=10 \
  --input-tokens=${INPUT_LENGTH} \
  --max-new-tokens=200 \
  --batch-size=${BATCH_SIZE} \
  --token-latency \
  --ipex \
  --autotp \
  --benchmark \
  --dtype bfloat16 \
  -m ./quant/llama2-72b-local_shard/llama_local_shard
```
Below is the raw data at input_length=2048
| batch size | 1 | 2 | 4 | 8 | 16 | 32 | 64 | 128 |
|---|---|---|---|---|---|---|---|---|
| TTFT (s) | 25.047 | 53.501 | 108.545 | 216.014 | 435.736 | 880.231 | 1749.547 | 3412.132 |
| TPOT (s) | 0.660 | 0.714 | 0.875 | 1.156 | 2.694 | 6.641 | 19.374 | 42.111 |
- Observation:
- TTFT is the prefill latency; TPOT is the per-token decode latency.
- Prefill: scaling is basically linear and compute stays saturated; doubling the batch size doubles the TTFT.
- Decode:
- For batch size <8, the TPOT hardly increases.
- However, when the batch size increases from 8 to 64 (8x), TPOT increases from 1.156 s to 19.374 s (~17x). This means a single batch of 64 is worse than splitting it into 8 runs of batch size 8 (see the quick comparison below). Does this mean the batch size is already saturated at 8?
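To make that concrete with the numbers from the table above (ignoring any overhead of running the smaller batches back to back): 8 sequential runs at batch size 8 advance 64 sequences by one token in about 8 × 1.156 s ≈ 9.2 s, whereas a single run at batch size 64 needs 19.374 s for the same amount of work, i.e. more than 2x slower per generated token.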
Because the decode-stage performance surprised me, I then used the roofline model to study its behavior.
Roofline (decode stage)
Experiment setting:
- Same as above, but on llama2-7b instead of llama2-72b (single instance).
- FLOPs and I/O are computed based on this blog.
Observations:
- Decode can be divided into the MLP part (including the QKV projections and the FFN at the end of each Transformer layer) and the attention part (SDPA).
- The MLP part conforms to the roofline model and reaches about 28 TFLOPS, close to the performance reported by the bf16 GEMM benchmark.
- The SDPA part does not improve with batch size: it saturates neither the bandwidth nor the compute, and only reaches about 700 GFLOPS.
- SDPA is therefore the direct cause of the decline in decode performance (a rough sketch of the FLOPs/bytes counting is below).
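For reference, here is a rough sketch of the FLOPs/bytes counting I have in mind for one decode step of llama2-7b (the model-shape constants and the counting itself are assumptions following the usual transformer roofline analysis, not something read out of IPEX). It illustrates why the two parts sit so differently on the roofline: the MLP weights are read once per step and shared across the batch, so the MLP arithmetic intensity grows with the batch size, while in MHA decode each sequence reads its own KV cache, so the SDPA arithmetic intensity stays around 1 FLOP/byte no matter how large the batch is.

```cpp
// Rough roofline counting for one decode step of llama2-7b (sketch).
// Assumed shapes: hidden=4096, heads=32, head_dim=128, intermediate=11008,
// layers=32, bf16 weights and KV cache (2 bytes/element); activation traffic
// is ignored because weight/KV-cache traffic dominates in decode.
#include <cstdio>

int main() {
  const double L = 32, H = 4096, nh = 32, dh = 128, I = 11008, bytes = 2;
  const double S = 2048;                     // KV-cache length (context so far)
  const double batch_sizes[] = {1, 8, 64, 128};
  for (double B : batch_sizes) {
    // MLP/GEMM part: QKV + output projection + gate/up/down FFN per layer.
    double w          = 4 * H * H + 3 * H * I;            // weight elements per layer
    double mlp_flops  = 2 * B * w * L;                     // 2 FLOPs per weight per token
    double mlp_bytes  = w * bytes * L;                     // weights read once, shared by the batch
    // SDPA part (MHA): each sequence reads its own K and V cache.
    double sdpa_flops = 4 * B * nh * dh * S * L;           // QK^T + PV
    double sdpa_bytes = 2 * B * nh * dh * S * bytes * L;   // K + V reads
    printf("B=%4.0f  MLP AI = %5.1f FLOP/byte   SDPA AI = %3.1f FLOP/byte\n",
           B, mlp_flops / mlp_bytes, sdpa_flops / sdpa_bytes);
  }
  return 0;
}
```

With an arithmetic intensity of about 1 FLOP/byte, the SDPA roofline ceiling is set by memory bandwidth rather than by the BF16/AMX compute peak, which is consistent with SDPA not benefiting from larger batches the way the MLP part does.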
SDPA kernel
SDPA implementation:
- The most frequently called kernels are reduce_head and mul_attenion_weights_and_value_of_head; both are essentially dot products.
- Their current implementation first casts bf16 to fp32 and then performs fp32 FMA. Perhaps this is the reason for the early saturation of compute?
Dot-product kernel benchmark:
- Taking reduce_head as an example, I made a simple comparison of several dot-product implementations:
- Version 0: cast bf16 to fp32, then do fp32 FMA (current IPEX behavior)

```cpp
void cvt_fma_bf16(const uint16_t* a, const uint16_t* b, float* c, size_t n) {
  size_t head_size = n;
  size_t vec_size = 256 / 8 / sizeof(uint16_t);  // 16 bf16 values per 256-bit load
  auto q_ptr_start = a;
  auto k_ptr_start = b;
  auto sum_vec = _mm512_setzero_ps();
  for (size_t hsi = 0; hsi <= head_size - vec_size; hsi += vec_size) {
    // load 16 bfloat16 query values from q_ptr_start and convert to 16 float32 values
    auto q_vec_bf16 = _mm256_loadu_si256((__m256i*)(q_ptr_start + hsi));
    auto q_vec_fp32 = convert_bf16_to_fp32(q_vec_bf16);
    // load 16 bfloat16 key values from k_ptr_start and convert to 16 float32 values
    auto k_vec_bf16 = _mm256_loadu_si256((__m256i*)(k_ptr_start + hsi));
    auto k_vec_fp32 = convert_bf16_to_fp32(k_vec_bf16);
    sum_vec = _mm512_fmadd_ps(q_vec_fp32, k_vec_fp32, sum_vec);
  }
  c[0] += _mm512_reduce_add_ps(sum_vec);
}
```
- Version 1: pure fp32 FMA

```cpp
void fma_f32(const float* a, const float* b, float* c, size_t n) {
  size_t vec_size = 512 / 8 / sizeof(float);  // 16 fp32 values per 512-bit load
  auto sum_vec = _mm512_setzero_ps();
  for (size_t i = 0; i <= n - vec_size; i += vec_size) {
    __m512 q = _mm512_load_ps(a + i);  // aligned loads: a and b must be 64-byte aligned
    __m512 k = _mm512_load_ps(b + i);
    sum_vec = _mm512_fmadd_ps(q, k, sum_vec);
  }
  c[0] += _mm512_reduce_add_ps(sum_vec);
}
```
- Version 2: pure fp16 FMA

```cpp
void fma_fp16(const uint16_t* a, const uint16_t* b, float* c, size_t n) {
  size_t head_size = n;
  size_t vec_size = 512 / 8 / sizeof(uint16_t);  // 32 fp16 values per 512-bit load
  auto sum_vec = _mm512_setzero_ph();
  for (size_t i = 0; i <= n - vec_size; i += vec_size) {
    auto q = _mm512_loadu_ph(a + i);
    auto k = _mm512_loadu_ph(b + i);
    sum_vec = _mm512_fmadd_ph(q, k, sum_vec);
  }
  c[0] += _mm512_reduce_add_ph(sum_vec);
}
```
- Version 3: pure bf16 dot product with _mm512_dpbf16_ps

```cpp
void fma_bf16(const uint16_t* a, const uint16_t* b, float* c, size_t n) {
  size_t head_size = n;
  size_t vec_size = 512 / 8 / sizeof(uint16_t);  // 32 bf16 values per 512-bit load
  auto sum_vec = _mm512_setzero_ps();
  for (size_t i = 0; i <= n - vec_size; i += vec_size) {
    auto q = _mm512_loadu_si512(a + i);
    auto k = _mm512_loadu_si512(b + i);
    sum_vec = _mm512_dpbf16_ps(sum_vec, (__m512bh)q, (__m512bh)k);
  }
  c[0] += _mm512_reduce_add_ps(sum_vec);
}
```
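For completeness, a minimal harness along these lines can be used to time the kernels above (a sketch only: the head size, iteration count, data values, and single-threaded timing method are my assumptions, not the exact setup behind the numbers in this issue):

```cpp
// Minimal single-core timing harness (sketch); compile together with the
// kernels above, e.g. g++ -O2 -mavx512f -mavx512bf16 -mavx512fp16 ...
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <cstdio>

// Kernels defined above.
void cvt_fma_bf16(const uint16_t* a, const uint16_t* b, float* c, size_t n);
void fma_bf16(const uint16_t* a, const uint16_t* b, float* c, size_t n);

template <typename Kernel>
double measure_gflops(Kernel kernel, const uint16_t* a, const uint16_t* b,
                      size_t n, size_t iters) {
  float acc = 0.f;
  auto t0 = std::chrono::steady_clock::now();
  for (size_t i = 0; i < iters; ++i) kernel(a, b, &acc, n);
  auto t1 = std::chrono::steady_clock::now();
  double sec = std::chrono::duration<double>(t1 - t0).count();
  // A dot product does 2 FLOPs (multiply + add) per element.
  return 2.0 * n * iters / sec / 1e9 + 0.0 * acc;  // "+ 0.0 * acc" keeps acc live
}

int main() {
  constexpr size_t n = 128;          // assumed head_size (llama2 head dim)
  constexpr size_t iters = 1 << 24;
  alignas(64) static uint16_t a[n], b[n];
  for (size_t i = 0; i < n; ++i) { a[i] = 0x3f80; b[i] = 0x3f80; }  // bf16 bit pattern of 1.0f
  printf("cast+fp32 fma (v0): %.1f GFLOPS\n", measure_gflops(cvt_fma_bf16, a, b, n, iters));
  printf("dpbf16        (v3): %.1f GFLOPS\n", measure_gflops(fma_bf16, a, b, n, iters));
  return 0;
}
```

Note that this measures a single core on a single head-sized dot product, so the numbers are not directly comparable to the multi-core GEMM benchmark figure; it only compares the relative cost of the convert-then-FMA path against the native bf16 dot-product path.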
The performance of all these kernels is far below 30 TFLOPS (the bf16 peak), so the decode stage does not take full advantage of the extra compute brought by AVX512_BF16 or AMX, and decode approaches compute saturation at a very small batch size.
My questions
- Is my experiment setup correct, and is this the expected performance of IPEX?
- Does IPEX have the opportunity to perform inference with large batches on the CPU (e.g., batch size = 256)? With vllm on a GPU, TPOT does not increase significantly when batch_size grows from 1 to 128.
- Is there a chance to optimize SDPA? For example, to boost its performance from 700 GFLOPS to ~10 TFLOPS?