bfloat16 is slower than float32 on Intel Xeon Platinum 8481C CPU #637

Open
netw0rkf10w opened this issue May 25, 2024 · 19 comments
Labels: CPU (CPU specific issues), Performance

@netw0rkf10w

Describe the bug

Running the following benchmark code:

from time import time
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
import intel_extension_for_pytorch as ipex


def get_embed_function(model_name):
    model = SentenceTransformer(model_name)
    model.eval()
    def embed_function(sentences):
        with torch.inference_mode():
            x = model.encode(sentences)
        return x

    return embed_function

def get_embed_function_optimized(model_name):
    model = SentenceTransformer(model_name)
    model.eval()
    # Cast the weights to bfloat16 and apply IPEX operator optimizations.
    model = ipex.optimize(model, dtype=torch.bfloat16)
    def embed_function(sentences):
        # Autocast runs the supported ops in bfloat16 on CPU.
        with torch.inference_mode(), torch.cpu.amp.autocast():
            x = model.encode(sentences)
        return x

    return embed_function

@torch.inference_mode()
def compute_embeddings(func, sentences, warmup=True, runs=1):
    latency_warmup = None
    if warmup:
        start = time()
        func(sentences)
        latency_warmup = time() - start
    latency = 0
    for _ in range(runs):
        start = time()
        embeddings = func(sentences)
        latency += time() - start
    return embeddings, latency / runs, latency_warmup


def benchmark():
    sentences = ["Un avion est en train de décoller.",
            "Un homme joue d'une grande flûte.",
            "Un homme étale du fromage râpé sur une pizza.",
            "Une personne jette un chat au plafond.",
            "Une personne est en train de plier un morceau de papier.",
            ]

    runs = 5

    # model_name = "all-MiniLM-L6-v2"
    model_name = "dangvantuan/sentence-camembert-large"

    e_sbert, t, tw = compute_embeddings(get_embed_function(model_name), sentences, warmup=True, runs=runs)
    print(f'sbert took: {t}, warmup: {tw}')

    e_sbert_opt, t, tw = compute_embeddings(get_embed_function_optimized(model_name), sentences, warmup=True, runs=runs)
    print(f'optimized sbert took: {t}, warmup: {tw}')

    print(f'diff = {np.linalg.norm(np.array(e_sbert) - np.array(e_sbert_opt), axis=1)}')

if __name__ == "__main__":
    benchmark()

I obtained (warnings excluded):

sbert took: 0.27114005088806153, warmup: 0.275803804397583
optimized sbert took: 0.2937626838684082, warmup: 0.2977313995361328
diff = [0.0784898  0.09684175 0.10466592 0.10803085 0.0929962 ]

Is this a bug or did I do something wrong?

The full output is below for your information:

2024-05-25 13:54:27,311 - SentenceTransformer.py - sentence_transformers.SentenceTransformer - WARNING - No sentence-transformers model found with name dangvantuan/sentence-camembert-large. Creating a new one with MEAN pooling.
/home/all/miniconda3/envs/env2/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
sbert took: 0.27114005088806153, warmup: 0.275803804397583
2024-05-25 13:54:29,823 - SentenceTransformer.py - sentence_transformers.SentenceTransformer - WARNING - No sentence-transformers model found with name dangvantuan/sentence-camembert-large. Creating a new one with MEAN pooling.
2024-05-25 13:54:31,237 - _logger.py - IPEX - WARNING - calling in ipex numpy which is not share memory with torch tensor for bfloat16 input.
[the same IPEX warning is repeated 30 times in total between 13:54:31 and 13:54:32]
optimized sbert took: 0.2937626838684082, warmup: 0.2977313995361328
diff = [0.0784898  0.09684175 0.10466592 0.10803085 0.0929962 ]

Versions

Collecting environment information...
PyTorch version: 2.3.0+cpu
PyTorch CXX11 ABI: No
IPEX version: 2.3.0+cpu
IPEX commit: 3ac92c8
Build type: Release

OS: Debian GNU/Linux 11 (bullseye) (x86_64)
GCC version: N/A
Clang version: N/A
IGC version: N/A
CMake version: N/A
Libc version: glibc-2.31

Python version: 3.10.13 (main, Sep 11 2023, 13:44:35) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.10.0-29-cloud-amd64-x86_64-with-glibc2.31
Is XPU available: False
DPCPP runtime version: N/A
MKL version: N/A
GPU models and configuration: 

Intel OpenCL ICD version: N/A
Level Zero version: N/A

CPU:
Architecture:                         x86_64
CPU op-mode(s):                       32-bit, 64-bit
Byte Order:                           Little Endian
Address sizes:                        52 bits physical, 57 bits virtual
CPU(s):                               4
On-line CPU(s) list:                  0-3
Thread(s) per core:                   2
Core(s) per socket:                   2
Socket(s):                            1
NUMA node(s):                         1
Vendor ID:                            GenuineIntel
CPU family:                           6
Model:                                143
Model name:                           Intel(R) Xeon(R) Platinum 8481C CPU @ 2.70GHz
Stepping:                             8
CPU MHz:                              2699.998
BogoMIPS:                             5399.99
Hypervisor vendor:                    KVM
Virtualization type:                  full
L1d cache:                            96 KiB
L1i cache:                            64 KiB
L2 cache:                             4 MiB
L3 cache:                             105 MiB
NUMA node0 CPU(s):                    0-3
Vulnerability Gather data sampling:   Not affected
Vulnerability Itlb multihit:          Not affected
Vulnerability L1tf:                   Not affected
Vulnerability Mds:                    Not affected
Vulnerability Meltdown:               Not affected
Vulnerability Mmio stale data:        Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed:               Not affected
Vulnerability Spec rstack overflow:   Not affected
Vulnerability Spec store bypass:      Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:             Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:             Mitigation; Enhanced / Automatic IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds:                  Not affected
Vulnerability Tsx async abort:        Not affected
Flags:                                fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid rtm avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx512_bf16 arat avx512vbmi umip avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid cldemote movdiri movdir64b fsrm md_clear serialize arch_capabilities

Versions of relevant libraries:
[pip3] intel-extension-for-pytorch==2.3.0
[pip3] mypy==1.7.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.2
[pip3] torch==2.3.0+cpu
[pip3] torchaudio==2.3.0+cpu
[pip3] torchlibrosa==0.1.0
[pip3] torchvision==0.18.0+cpu
[conda] intel-extension-for-pytorch 2.3.0                    pypi_0    pypi
[conda] numpy                     1.26.2                   pypi_0    pypi
[conda] torch                     2.3.0+cpu                pypi_0    pypi
[conda] torchaudio                2.3.0+cpu                pypi_0    pypi
[conda] torchlibrosa              0.1.0                    pypi_0    pypi
[conda] torchvision               0.16.2                   pypi_0    pypi
@jgong5 (Contributor) commented May 27, 2024

It is not expected that the "Linear" ops would be slower, but there could be something else going on in the model that slows things down. Is it possible to share the PyTorch profiles for the fp32 and bf16 runs?
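
For reference, a minimal sketch of how such a profile could be collected with torch.profiler (the model name and sentence are just copied from the report above; adjust as needed):

import torch
from torch.profiler import profile, ProfilerActivity
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("dangvantuan/sentence-camembert-large")
model.eval()
sentences = ["Un avion est en train de décoller."]

with profile(activities=[ProfilerActivity.CPU]) as prof:
    with torch.inference_mode():
        model.encode(sentences)

# Print the ten most expensive CPU ops as a table.
print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))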

@netw0rkf10w (Author)

@jgong5 Thanks for your reply.

I have realized that bfloat16 is actually faster for batch size 1 (tested with the first sentence in the above example).

The profiling results (together with the full output logs, in case you need them) for both cases are below:

Batch size 1:

STAGE:2024-05-27 07:52:03 1240:1240 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
2024-05-27 07:52:03,377 - SentenceTransformer.py - sentence_transformers.SentenceTransformer - WARNING - No sentence-transformers model found with name dangvantuan/sentence-camembert-large. Creating a new one with MEAN pooling.
/home/all/miniconda3/envs/env2/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
STAGE:2024-05-27 07:52:05 1240:1240 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-05-27 07:52:05 1240:1240 ActivityProfilerController.cpp:324] Completed Stage: Post Processing
sbert took: 0.15806293487548828, warmup: 0.16187143325805664
-------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                       Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                               aten::linear         0.48%       5.481ms        71.57%     813.928ms     935.549us           870  
                                aten::addmm        70.23%     798.706ms        70.81%     805.296ms     925.628us           870  
                                aten::copy_        20.51%     233.233ms        20.51%     233.233ms     145.498us          1603  
                                aten::empty         3.31%      37.681ms         3.31%      37.681ms      19.245us          1958  
                           aten::layer_norm         0.16%       1.868ms         1.47%      16.724ms      56.884us           294  
                     torch_ipex::layer_norm         1.36%      15.494ms         1.42%      16.147ms      54.922us           294  
                                 aten::gelu         1.10%      12.524ms         1.10%      12.524ms      86.972us           144  
                               aten::matmul         0.24%       2.762ms         0.81%       9.223ms      32.024us           288  
                                  aten::bmm         0.41%       4.628ms         0.41%       4.628ms      16.069us           288  
                              aten::softmax         0.05%     539.000us         0.32%       3.669ms      25.479us           144  
-------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 1.137s

STAGE:2024-05-27 07:52:07 1240:1240 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
2024-05-27 07:52:07,332 - SentenceTransformer.py - sentence_transformers.SentenceTransformer - WARNING - No sentence-transformers model found with name dangvantuan/sentence-camembert-large. Creating a new one with MEAN pooling.
2024-05-27 07:52:08,520 - _logger.py - IPEX - WARNING - calling in ipex numpy which is not share memory with torch tensor for bfloat16 input.
2024-05-27 07:52:08,594 - _logger.py - IPEX - WARNING - calling in ipex numpy which is not share memory with torch tensor for bfloat16 input.
2024-05-27 07:52:08,667 - _logger.py - IPEX - WARNING - calling in ipex numpy which is not share memory with torch tensor for bfloat16 input.
2024-05-27 07:52:08,740 - _logger.py - IPEX - WARNING - calling in ipex numpy which is not share memory with torch tensor for bfloat16 input.
2024-05-27 07:52:08,814 - _logger.py - IPEX - WARNING - calling in ipex numpy which is not share memory with torch tensor for bfloat16 input.
2024-05-27 07:52:08,888 - _logger.py - IPEX - WARNING - calling in ipex numpy which is not share memory with torch tensor for bfloat16 input.
STAGE:2024-05-27 07:52:08 1240:1240 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-05-27 07:52:08 1240:1240 ActivityProfilerController.cpp:324] Completed Stage: Post Processing
optimized sbert took: 0.07358126640319824, warmup: 0.07696318626403809
----------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
----------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                   aten::copy_        48.10%     474.969ms        48.10%     474.969ms     221.224us          2147  
                       torch_ipex::ipex_linear        31.80%     314.047ms        35.03%     345.934ms     323.605us          1069  
                                   aten::clone         0.42%       4.131ms        20.44%     201.832ms     208.720us           967  
                  ipex_prepack::linear_prepack         0.09%     915.000us        10.31%     101.785ms     701.966us           145  
    ipex_prepack::createLinearPrePackOpContext        10.19%     100.592ms        10.21%     100.870ms     695.655us           145  
                                      aten::to         1.87%      18.514ms         9.90%      97.769ms      17.116us          5712  
                                aten::_to_copy         0.23%       2.257ms         9.83%      97.096ms     123.375us           787  
                                  aten::matmul         0.39%       3.876ms         2.54%      25.106ms      67.128us           374  
                           aten::empty_strided         1.84%      18.144ms         1.84%      18.144ms      12.326us          1472  
                                     aten::bmm         1.19%      11.763ms         1.68%      16.636ms      57.764us           288  
----------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 987.516ms

diff = [0.07849015]

Batch size 5:

STAGE:2024-05-27 07:52:46 1274:1274 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
2024-05-27 07:52:46,928 - SentenceTransformer.py - sentence_transformers.SentenceTransformer - WARNING - No sentence-transformers model found with name dangvantuan/sentence-camembert-large. Creating a new one with MEAN pooling.
/home/all/miniconda3/envs/env2/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
STAGE:2024-05-27 07:52:49 1274:1274 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-05-27 07:52:49 1274:1274 ActivityProfilerController.cpp:324] Completed Stage: Post Processing
sbert took: 0.2906981945037842, warmup: 0.3036918640136719
-------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                       Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                               aten::linear         0.28%       5.602ms        78.24%        1.542s       1.772ms           870  
                                aten::addmm        76.32%        1.504s        77.79%        1.533s       1.762ms           870  
                                aten::copy_        13.25%     261.052ms        13.25%     261.052ms     127.904us          2041  
                                aten::empty         3.51%      69.109ms         3.51%      69.109ms      28.843us          2396  
                                 aten::gelu         2.04%      40.180ms         2.04%      40.180ms     279.028us           144  
                              aten::softmax         0.10%       1.883ms         1.35%      26.565ms     184.479us           144  
                             aten::_softmax         1.34%      26.318ms         1.34%      26.318ms     182.764us           144  
                           aten::layer_norm         0.10%       1.946ms         1.18%      23.341ms      79.391us           294  
                               aten::matmul         0.13%       2.491ms         1.18%      23.298ms      80.896us           288  
                     torch_ipex::layer_norm         1.01%      19.834ms         1.16%      22.776ms      77.469us           294  
-------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 1.971s

STAGE:2024-05-27 07:52:51 1274:1274 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
2024-05-27 07:52:51,747 - SentenceTransformer.py - sentence_transformers.SentenceTransformer - WARNING - No sentence-transformers model found with name dangvantuan/sentence-camembert-large. Creating a new one with MEAN pooling.
2024-05-27 07:52:53,276 - _logger.py - IPEX - WARNING - calling in ipex numpy which is not share memory with torch tensor for bfloat16 input.
[the same IPEX warning is repeated 30 times in total between 07:52:53 and 07:52:54]
STAGE:2024-05-27 07:52:54 1274:1274 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-05-27 07:52:54 1274:1274 ActivityProfilerController.cpp:324] Completed Stage: Post Processing
optimized sbert took: 0.29979958534240725, warmup: 0.3086233139038086
----------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
----------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                       torch_ipex::ipex_linear        66.96%        1.668s        67.20%        1.674s       1.850ms           905  
                                   aten::copy_        21.61%     538.511ms        21.61%     538.511ms     246.684us          2183  
                                   aten::clone         0.13%       3.191ms         9.41%     234.400ms     239.428us           979  
                                      aten::to         0.46%      11.345ms         4.20%     104.568ms      18.154us          5760  
                                aten::_to_copy         0.10%       2.451ms         4.17%     103.808ms     128.000us           811  
                  ipex_prepack::linear_prepack         0.04%       1.049ms         4.09%     102.005ms     703.483us           145  
    ipex_prepack::createLinearPrePackOpContext         4.04%     100.643ms         4.05%     100.956ms     696.248us           145  
                                   aten::empty         1.43%      35.530ms         1.43%      35.530ms       7.754us          4582  
                                  aten::matmul         0.19%       4.843ms         1.40%      34.862ms     111.380us           313  
                           aten::empty_strided         1.37%      34.103ms         1.37%      34.103ms      22.796us          1496  
----------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 2.492s

diff = [0.0784898  0.09684175 0.10466592 0.10803085 0.0929962 ]

@jgong5 (Contributor) commented May 27, 2024

cc @zhuhaozhe
@netw0rkf10w Do you mind sharing the oneDNN verbose log, captured with ONEDNN_VERBOSE=1?
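
A minimal sketch of one way to capture it from Python (purely illustrative; equivalently, the variable can be set in the shell when launching the benchmark script):

import os

# Ask oneDNN to print one line per executed primitive (ISA, shapes, timings).
# Set it before the first oneDNN primitive runs, e.g. before importing torch.
os.environ["ONEDNN_VERBOSE"] = "1"

import torch  # noqa: E402
# ... run the benchmark as before; the verbose lines are written to stdout.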

@netw0rkf10w (Author)

@jgong5 Yes the full output can be found in the attached file.
ipex_bfloat16_bug.txt

@jgong5 (Contributor) commented May 27, 2024

@jgong5 Yes the full output can be found in the attached file. ipex_bfloat16_bug.txt

Weird, it seems the kernels were running with avx512-bf16, not AMX. @zhuhaozhe, can you double-check?

@zhuhaozhe (Contributor)

@jgong5 Yes the full output can be found in the attached file. ipex_bfloat16_bug.txt

Weird, it seems kernels were running with avx512-bf16, not amx. @zhuhaozhe can you double check?

Sure, I will check it.

@xiguiw added the CPU (CPU specific issues) and Performance labels on May 28, 2024
@zhuhaozhe (Contributor) commented May 28, 2024

Hi, @jgong5.
From the CPU flags provided by the reporter, AMX is not included:

Flags:                                fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid rtm avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx512_bf16 arat avx512vbmi umip avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid cldemote movdiri movdir64b fsrm md_clear serialize arch_capabilities

while on the SPR machine I can access, there are more flags, including amx_bf16, avx512_fp16, amx_tile, and amx_int8.

But the 8481C should be Sapphire Rapids. Do you know where we can get such hardware and try to reproduce on it?

@netw0rkf10w (Author)

@zhuhaozhe Thanks for your investigation. It's a c3-standard-4 VM instance from Google Cloud. Please let me know if you need further information.

@zhuhaozhe (Contributor)

@zhuhaozhe Thanks for your investigation. It's a c3-standard-4 VM instance from Google Cloud. Please let me know if you need further information.

Thanks for the clarification.
Hi @jingxu10, do you know where we can access such a VM instance?

@jingxu10 (Contributor)

@zhuhaozhe SPR instances in CSPs might not have AMX enabled. The 8481C seems to be used in GCP:
https://cloud.google.com/compute/docs/cpu-platforms

@WilliamTambellini

Your Linux kernel is also too old for AMX detection.

@zhuhaozhe (Contributor)

Thanks for the document @jingxu10.

Your linux kernel is also too old for amx detection

This is correct. The document (https://cloud.google.com/compute/docs/cpu-platforms#intel-amx) mentions that a Linux kernel newer than 5.16 is needed to enable AMX.

@netw0rkf10w Could you check your kernel version with uname -r? From the environment report it looks like it is 5.10:

Python platform: Linux-5.10.0-29-cloud-amd64-x86_64-with-glibc2.31
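
A minimal sketch for checking both points from inside the VM (purely illustrative):

import platform

# Kernel version: AMX needs a sufficiently recent kernel (see the GCP doc above).
print("kernel:", platform.release())

# Check whether the guest CPU exposes any AMX feature flags.
with open("/proc/cpuinfo") as f:
    flags = next(line for line in f if line.startswith("flags")).split()
print("amx flags:", [fl for fl in flags if fl.startswith("amx")] or "none exposed")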

@netw0rkf10w (Author) commented May 29, 2024

@WilliamTambellini Good catch! Thanks!
@zhuhaozhe It's indeed 5.10. I'll upgrade the OS and get back to you.

@netw0rkf10w (Author) commented Jun 12, 2024

I confirm that after upgrading the Linux kernel to 6.1.0, I obtained improved performance:

sbert took: 0.27169055938720704, warmup: 0.2765231132507324
optimized sbert took: 0.11300868988037109, warmup: 0.11607575416564941

@zhuhaozhe Is the 2x improvement as expected, or could it be improved further? Thanks!

@zhuhaozhe (Contributor)

@netw0rkf10w, a 2.6x improvement looks good to me. Can you share the profiling results as well?

@netw0rkf10w (Author)

@zhuhaozhe Yes, the full outputs are below (the first block is without ipex).

By the way, could you please tell me whether the warning "calling in ipex numpy which is not share memory with torch tensor for bfloat16 input" is normal? Thanks a lot!

$ python intel_bug.py 
STAGE:2024-06-13 15:35:40 4820:4820 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
2024-06-13 15:35:40,830 - SentenceTransformer.py - sentence_transformers.SentenceTransformer - WARNING - No sentence-transformers model found with name dangvantuan/sentence-camembert-large. Creating a new one with MEAN pooling.
/home/all/miniconda3/envs/env2/lib/python3.11/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
STAGE:2024-06-13 15:35:51 4820:4820 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-06-13 15:35:51 4820:4820 ActivityProfilerController.cpp:324] Completed Stage: Post Processing
-------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                       Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                aten::copy_        77.67%        7.553s        77.67%        7.553s       4.712ms          1603  
                               aten::linear         0.07%       6.938ms        17.23%        1.676s       1.926ms           870  
                                aten::addmm        16.99%        1.652s        17.12%        1.665s       1.914ms           870  
                                 aten::item         0.07%       6.875ms         3.17%     308.594ms     787.230us           392  
                  aten::_local_scalar_dense         3.17%     308.350ms         3.17%     308.350ms     786.607us           392  
                                 aten::gelu         0.44%      42.788ms         0.44%      42.788ms     297.139us           144  
                                aten::empty         0.37%      36.206ms         0.37%      36.206ms      18.491us          1958  
                           aten::layer_norm         0.02%       2.167ms         0.37%      35.913ms     122.153us           294  
                     torch_ipex::layer_norm         0.36%      34.642ms         0.36%      35.387ms     120.364us           294  
                               aten::matmul         0.04%       3.577ms         0.22%      21.840ms      75.833us           288  
-------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 9.724s

STAGE:2024-06-13 15:35:52 4820:4820 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
2024-06-13 15:35:52,953 - SentenceTransformer.py - sentence_transformers.SentenceTransformer - WARNING - No sentence-transformers model found with name dangvantuan/sentence-camembert-large. Creating a new one with MEAN pooling.
2024-06-13 15:35:54,838 - _logger.py - IPEX - WARNING - calling in ipex numpy which is not share memory with torch tensor for bfloat16 input.
2024-06-13 15:35:54,930 - _logger.py - IPEX - WARNING - calling in ipex numpy which is not share memory with torch tensor for bfloat16 input.
2024-06-13 15:35:55,021 - _logger.py - IPEX - WARNING - calling in ipex numpy which is not share memory with torch tensor for bfloat16 input.
2024-06-13 15:35:55,112 - _logger.py - IPEX - WARNING - calling in ipex numpy which is not share memory with torch tensor for bfloat16 input.
2024-06-13 15:35:55,202 - _logger.py - IPEX - WARNING - calling in ipex numpy which is not share memory with torch tensor for bfloat16 input.
2024-06-13 15:35:55,292 - _logger.py - IPEX - WARNING - calling in ipex numpy which is not share memory with torch tensor for bfloat16 input.
STAGE:2024-06-13 15:35:55 4820:4820 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-06-13 15:35:55 4820:4820 ActivityProfilerController.cpp:324] Completed Stage: Post Processing
----------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
----------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                   aten::copy_        53.20%     830.317ms        53.20%     830.317ms     386.734us          2147  
                       torch_ipex::ipex_linear        26.00%     405.820ms        27.18%     424.165ms     424.590us           999  
                                   aten::clone         0.19%       3.016ms        19.86%     309.966ms     320.544us           967  
                  ipex_prepack::linear_prepack         0.13%       2.066ms        11.87%     185.259ms       1.278ms           145  
    ipex_prepack::createLinearPrePackOpContext        11.71%     182.840ms        11.74%     183.193ms       1.263ms           145  
                                      aten::to         1.56%      24.379ms        10.69%     166.914ms      29.222us          5712  
                                aten::_to_copy         0.21%       3.285ms        10.65%     166.210ms     211.194us           787  
                                  aten::matmul         0.25%       3.964ms         2.94%      45.885ms     124.014us           370  
                                     aten::bmm         1.89%      29.469ms         2.35%      36.611ms     127.122us           288  
                                 aten::softmax         0.04%     641.000us         1.13%      17.574ms     122.042us           144  
----------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 1.561s

diff = [0.09288985]

@zhuhaozhe (Contributor)

Thanks for your input, @netw0rkf10w.
It looks like your two profiling tables are not an apples-to-apples comparison, so I cannot see the actual speedup.
Usually we expect the GEMM ops (like the bmm, matmul, and linear calls in your profiles) to benefit from Intel AMX, and we expect memory-bound ops to benefit from the lower memory footprint (16 bits for BF16 vs. 32 bits for FP32): https://www.intel.com/content/www/us/en/developer/articles/technical/accelerate-pytorch-training-inference-on-amx.html
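
As a quick sanity check, here is a minimal sketch that times a single large matmul in fp32 vs. bf16 in isolation (sizes are arbitrary; on an AMX-enabled machine the bf16 run should be clearly faster):

import time
import torch

x = torch.randn(2048, 2048)
w = torch.randn(2048, 2048)

def bench(a, b, iters=20):
    # One warm-up call, then average the wall-clock time of a matmul.
    torch.matmul(a, b)
    start = time.time()
    for _ in range(iters):
        torch.matmul(a, b)
    return (time.time() - start) / iters

print("fp32:", bench(x, w))
print("bf16:", bench(x.to(torch.bfloat16), w.to(torch.bfloat16)))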

The warning message indicates that you are calling t.numpy() to convert a PyTorch tensor t (with dtype=torch.bfloat16) to a NumPy array; since NumPy does not support bf16, IPEX converts it to float32 first. This is a cost of using BF16, but we need more detailed information to know whether it is a large cost.
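
If that conversion turns out to be costly, one possible workaround (a sketch, assuming the sentence-transformers encode API and its convert_to_tensor argument) is to keep the embeddings as a torch tensor and convert to NumPy once, explicitly:

with torch.inference_mode(), torch.cpu.amp.autocast():
    # convert_to_tensor=True avoids the implicit bf16 -> float32 -> numpy copy per call.
    embeddings = model.encode(sentences, convert_to_tensor=True)
embeddings_np = embeddings.float().numpy()  # one explicit conversion when numpy is needed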

@netw0rkf10w (Author)

@zhuhaozhe Thanks a lot for your reply!
Could you please tell me what details you need for an accurate assessment? As you can see from the code, the only difference between the two versions is model = ipex.optimize(model, dtype=torch.bfloat16) (and torch.cpu.amp.autocast()), so it is not clear to me why this is not an apples-to-apples comparison.

By the way, I would like to ask another question if you don't mind. I would like to serve this model with OpenVINO Model Server. However, the latter only supports model formats that don't seem to support bfloat16. Could you please tell me how to serve a model optimized with IPEX? Thank you very much in advance for your help!

@zhuhaozhe (Contributor)

Hi, @netw0rkf10w.
When we compare bf16/fp32 performance from profiles, we usually expect the number of calls to be the same; then we can compare the "CPU time avg" column to see the speedup. However, your workload seems to have some dynamic behavior, so the numbers of calls differ, and we do not have enough information here to tell the exact speedup.

As for OpenVINO Model Server, I have not used it. Hi @jingxu10, do you have more information here?
