A3 ultra deepseek inference sglang recipe issues #5

Open
rick-c-goog opened this issue Apr 12, 2025 · 1 comment

rick-c-goog commented Apr 12, 2025

The A3 Ultra DeepSeek inference SGLang Helm chart is not working; I hit multiple issues:

  1. Helm chart template issue:
    https://github.com/AI-Hypercomputer/gpu-recipes/blob/main/src/helm-charts/a3ultra/sglang-inference/templates/model-serve-launcher.yaml#L149
    {{- if .$rootValues.sglang.serverArgs }}
    should be corrected to
    {{- if $root.Values.sglang.serverArgs }}
    (see the sketch after this list)

  2. Error when sourcing set_nccl_env.sh:
    2025-04-12 12:36:50.511 EDT
    /usr/local/gib/scripts/set_nccl_env.sh: line 16: lspci: command not found
    2025-04-12 12:36:50.513 EDT
    /usr/local/gib/scripts/set_nccl_env.sh: line 17: lspci: command not found

    The fix is to add
    apt-get install --yes pciutils
    to https://github.com/AI-Hypercomputer/gpu-recipes/blob/main/src/docker/sglang/sglang.Dockerfile (see the sketch after this list).

  3. Runtime errors related to the gIB network not being found: libnccl-net.so cannot be loaded, which conflicts with the SGLang NCCL settings. I also tried setting the NCCL environment variables required by A3 Ultra manually and got the same errors (full kubectl log and a short diagnostic sketch below).
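
For (1), the reason the current conditional fails is that inside a range/with block the template context "." is rebound, so values have to be reached through a saved root variable. A minimal sketch of what the corrected block could look like; the $root := . binding and the list-style rendering of serverArgs are my assumptions, not the chart's actual code:

    {{- /* assumed binding; the chart may already define $root elsewhere */ -}}
    {{- $root := . -}}
    {{- if $root.Values.sglang.serverArgs }}
    {{- range $root.Values.sglang.serverArgs }}
    - {{ . | quote }}
    {{- end }}
    {{- end }}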
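
For (2), the Dockerfile change is a one-line RUN step; a sketch (exact placement inside sglang.Dockerfile is up to the maintainers):

    # lspci comes from pciutils; set_nccl_env.sh calls it (lines 16 and 17 in the error above)
    RUN apt-get update && apt-get install --yes pciutils && rm -rf /var/lib/apt/lists/*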

kubectl logs rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp -c serving
/usr/local/gib/configs/tuner_config_a3u.txtpb
DP attention is enabled. The chunked prefill size is adjusted to 4096 to avoid MoE kernel issues. The schedule conservativeness is adjusted to 0.3. Data parallel size is adjusted to be the same as tensor parallel size.
[2025-04-12 09:00:48] server_args=ServerArgs(model_path='deepseek-ai/DeepSeek-R1', tokenizer_path='deepseek-ai/DeepSeek-R1', tokenizer_mode='auto', load_format='auto', trust_remote_code=True, dtype='auto', kv_cache_dtype='auto', quantization_param_path=None, quantization=None, context_length=None, device='cuda', served_model_name='deepseek-ai/DeepSeek-R1', chat_template=None, is_embedding=False, revision=None, skip_tokenizer_init=False, host='127.0.0.1', port=30000, mem_fraction_static=0.81, max_running_requests=None, max_total_tokens=None, chunked_prefill_size=4096, max_prefill_tokens=16384, schedule_policy='lpm', schedule_conservativeness=0.3, cpu_offload_gb=0, prefill_only_one_req=False, tp_size=8, stream_interval=1, stream_output=False, random_seed=1003314541, constrained_json_whitespace_pattern=None, watchdog_timeout=300, download_dir=None, base_gpu_id=0, log_level='info', log_level_http=None, log_requests=False, show_time_cost=False, enable_metrics=False, decode_log_interval=40, api_key=None, file_storage_pth='sglang_storage', enable_cache_report=False, dp_size=8, load_balance_method='round_robin', ep_size=1, dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', lora_paths=None, max_loras_per_batch=8, attention_backend='flashinfer', sampling_backend='flashinfer', grammar_backend='outlines', speculative_draft_model_path=None, speculative_algorithm=None, speculative_num_steps=5, speculative_num_draft_tokens=64, speculative_eagle_topk=8, enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, disable_radix_cache=False, disable_jump_forward=False, disable_cuda_graph=False, disable_cuda_graph_padding=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, disable_mla=False, disable_overlap_schedule=False, enable_mixed_chunk=False, enable_dp_attention=True, enable_ep_moe=False, enable_torch_compile=False, torch_compile_max_bs=32, cuda_graph_max_bs=160, cuda_graph_bs=None, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, allow_auto_truncate=False, enable_custom_logit_processor=False, tool_call_parser=None, enable_hierarchical_cache=False)
A new version of the following files was downloaded from https://huggingface.co/deepseek-ai/DeepSeek-R1:
- configuration_deepseek.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
    [2025-04-12 09:01:02 DP6 TP6] MLA optimization is turned on. Use triton backend.
    [2025-04-12 09:01:02 DP6 TP6] Init torch distributed begin.
    [2025-04-12 09:01:03 DP5 TP5] MLA optimization is turned on. Use triton backend.
    [2025-04-12 09:01:03 DP5 TP5] Init torch distributed begin.
    [2025-04-12 09:01:03 DP4 TP4] MLA optimization is turned on. Use triton backend.
    [2025-04-12 09:01:03 DP4 TP4] Init torch distributed begin.
    [2025-04-12 09:01:03 DP2 TP2] MLA optimization is turned on. Use triton backend.
    [2025-04-12 09:01:03 DP2 TP2] Init torch distributed begin.
    [2025-04-12 09:01:03 DP1 TP1] MLA optimization is turned on. Use triton backend.
    [2025-04-12 09:01:03 DP1 TP1] Init torch distributed begin.
    [2025-04-12 09:01:03 DP0 TP0] MLA optimization is turned on. Use triton backend.
    [2025-04-12 09:01:03 DP0 TP0] Init torch distributed begin.
    [2025-04-12 09:01:04 DP7 TP7] MLA optimization is turned on. Use triton backend.
    [2025-04-12 09:01:04 DP7 TP7] Init torch distributed begin.
    [2025-04-12 09:01:06 DP3 TP3] MLA optimization is turned on. Use triton backend.
    [2025-04-12 09:01:06 DP3 TP3] Init torch distributed begin.
    [2025-04-12 09:01:07 DP0 TP0] sglang is using nccl==2.21.5
    [2025-04-12 09:01:07 DP1 TP1] sglang is using nccl==2.21.5
    [2025-04-12 09:01:07 DP3 TP3] sglang is using nccl==2.21.5
    [2025-04-12 09:01:07 DP2 TP2] sglang is using nccl==2.21.5
    [2025-04-12 09:01:07 DP6 TP6] sglang is using nccl==2.21.5
    [2025-04-12 09:01:07 DP4 TP4] sglang is using nccl==2.21.5
    [2025-04-12 09:01:07 DP5 TP5] sglang is using nccl==2.21.5
    [2025-04-12 09:01:07 DP7 TP7] sglang is using nccl==2.21.5
    rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:554:554 [0] NCCL INFO Bootstrap : Using eth0:10.8.10.18<0>
    rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:554:554 [0] NCCL INFO NET/Plugin: No plugin found (libnccl-net.so)
    rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:554:554 [0] NCCL INFO NET/Plugin: Plugin load returned 2 : /lib/x86_64-linux-gnu/libc.so.6: version `GLIBC_2.32' not found (required by /usr/local/nvidia/lib64/libnccl-net.so) : when loading libnccl-net.so
    rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:554:554 [0] NCCL INFO NET/Plugin: Using internal network plugin.
    rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:554:554 [0] NCCL INFO cudaDriverVersion 12080
    NCCL version 2.21.5+cuda12.4
    [2025-04-12 09:01:07 DP3 TP3] Scheduler hit an exception: Traceback (most recent call last):
    File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 1787, in run_scheduler_process
    scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
    File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 240, in init
    self.tp_worker = TpWorkerClass(
    File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker_overlap_thread.py", line 63, in init
    self.worker = TpModelWorker(server_args, gpu_id, tp_rank, dp_rank, nccl_port)
    File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker.py", line 68, in init
    self.model_runner = ModelRunner(
    File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 177, in init
    min_per_gpu_memory = self.init_torch_distributed()
    File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 252, in init_torch_distributed
    initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
    File "/sgl-workspace/sglang/python/sglang/srt/distributed/parallel_state.py", line 1055, in initialize_model_parallel
    _TP = init_model_parallel_group(
    File "/sgl-workspace/sglang/python/sglang/srt/distributed/parallel_state.py", line 890, in init_model_parallel_group
    return GroupCoordinator(
    File "/sgl-workspace/sglang/python/sglang/srt/distributed/parallel_state.py", line 233, in init
    self.pynccl_comm = PyNcclCommunicator(
    File "/sgl-workspace/sglang/python/sglang/srt/distributed/device_communicators/pynccl.py", line 108, in init
    self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
    File "/sgl-workspace/sglang/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py", line 350, in ncclCommInitRank
    self.NCCL_CHECK(
    File "/sgl-workspace/sglang/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py", line 329, in NCCL_CHECK
    raise RuntimeError(f"NCCL error: {error_str}")
    RuntimeError: NCCL error: invalid usage (run with NCCL_DEBUG=WARN for details)

[2025-04-12 09:01:07 DP4 TP4] Scheduler hit an exception: Traceback (most recent call last):
File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 1787, in run_scheduler_process
scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 240, in init
self.tp_worker = TpWorkerClass(
File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker_overlap_thread.py", line 63, in init
self.worker = TpModelWorker(server_args, gpu_id, tp_rank, dp_rank, nccl_port)
File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker.py", line 68, in init
self.model_runner = ModelRunner(
File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 177, in init
min_per_gpu_memory = self.init_torch_distributed()
File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 252, in init_torch_distributed
initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
File "/sgl-workspace/sglang/python/sglang/srt/distributed/parallel_state.py", line 1055, in initialize_model_parallel
_TP = init_model_parallel_group(
File "/sgl-workspace/sglang/python/sglang/srt/distributed/parallel_state.py", line 890, in init_model_parallel_group
return GroupCoordinator(
File "/sgl-workspace/sglang/python/sglang/srt/distributed/parallel_state.py", line 233, in init
self.pynccl_comm = PyNcclCommunicator(
File "/sgl-workspace/sglang/python/sglang/srt/distributed/device_communicators/pynccl.py", line 108, in init
self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
File "/sgl-workspace/sglang/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py", line 350, in ncclCommInitRank
self.NCCL_CHECK(
File "/sgl-workspace/sglang/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py", line 329, in NCCL_CHECK
raise RuntimeError(f"NCCL error: {error_str}")
RuntimeError: NCCL error: invalid usage (run with NCCL_DEBUG=WARN for details)

[2025-04-12 09:01:07 DP1 TP1] Scheduler hit an exception: Traceback (most recent call last):
File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 1787, in run_scheduler_process
scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 240, in init
self.tp_worker = TpWorkerClass(
File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker_overlap_thread.py", line 63, in init
self.worker = TpModelWorker(server_args, gpu_id, tp_rank, dp_rank, nccl_port)
File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker.py", line 68, in init
self.model_runner = ModelRunner(
File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 177, in init
min_per_gpu_memory = self.init_torch_distributed()
File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 252, in init_torch_distributed
initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
File "/sgl-workspace/sglang/python/sglang/srt/distributed/parallel_state.py", line 1055, in initialize_model_parallel
_TP = init_model_parallel_group(
File "/sgl-workspace/sglang/python/sglang/srt/distributed/parallel_state.py", line 890, in init_model_parallel_group
return GroupCoordinator(
File "/sgl-workspace/sglang/python/sglang/srt/distributed/parallel_state.py", line 233, in init
self.pynccl_comm = PyNcclCommunicator(
File "/sgl-workspace/sglang/python/sglang/srt/distributed/device_communicators/pynccl.py", line 108, in init
self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
File "/sgl-workspace/sglang/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py", line 350, in ncclCommInitRank
self.NCCL_CHECK(
File "/sgl-workspace/sglang/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py", line 329, in NCCL_CHECK
raise RuntimeError(f"NCCL error: {error_str}")
RuntimeError: NCCL error: invalid usage (run with NCCL_DEBUG=WARN for details)

[2025-04-12 09:01:07 DP6 TP6] Scheduler hit an exception: Traceback (most recent call last):
File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 1787, in run_scheduler_process
scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 240, in init
self.tp_worker = TpWorkerClass(
File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker_overlap_thread.py", line 63, in init
self.worker = TpModelWorker(server_args, gpu_id, tp_rank, dp_rank, nccl_port)
File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker.py", line 68, in init
self.model_runner = ModelRunner(
File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 177, in init
min_per_gpu_memory = self.init_torch_distributed()
File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 252, in init_torch_distributed
initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
File "/sgl-workspace/sglang/python/sglang/srt/distributed/parallel_state.py", line 1055, in initialize_model_parallel
_TP = init_model_parallel_group(
File "/sgl-workspace/sglang/python/sglang/srt/distributed/parallel_state.py", line 890, in init_model_parallel_group
return GroupCoordinator(
File "/sgl-workspace/sglang/python/sglang/srt/distributed/parallel_state.py", line 233, in init
self.pynccl_comm = PyNcclCommunicator(
File "/sgl-workspace/sglang/python/sglang/srt/distributed/device_communicators/pynccl.py", line 108, in init
self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
File "/sgl-workspace/sglang/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py", line 350, in ncclCommInitRank
self.NCCL_CHECK(
File "/sgl-workspace/sglang/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py", line 329, in NCCL_CHECK
raise RuntimeError(f"NCCL error: {error_str}")
RuntimeError: NCCL error: invalid usage (run with NCCL_DEBUG=WARN for details)

[2025-04-12 09:01:07 DP5 TP5] Scheduler hit an exception: Traceback (most recent call last):
File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 1787, in run_scheduler_process
scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 240, in init
self.tp_worker = TpWorkerClass(
File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker_overlap_thread.py", line 63, in init
self.worker = TpModelWorker(server_args, gpu_id, tp_rank, dp_rank, nccl_port)
File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker.py", line 68, in init
self.model_runner = ModelRunner(
File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 177, in init
min_per_gpu_memory = self.init_torch_distributed()
File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 252, in init_torch_distributed
initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
File "/sgl-workspace/sglang/python/sglang/srt/distributed/parallel_state.py", line 1055, in initialize_model_parallel
_TP = init_model_parallel_group(
File "/sgl-workspace/sglang/python/sglang/srt/distributed/parallel_state.py", line 890, in init_model_parallel_group
return GroupCoordinator(
File "/sgl-workspace/sglang/python/sglang/srt/distributed/parallel_state.py", line 233, in init
self.pynccl_comm = PyNcclCommunicator(
File "/sgl-workspace/sglang/python/sglang/srt/distributed/device_communicators/pynccl.py", line 108, in init
self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
File "/sgl-workspace/sglang/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py", line 350, in ncclCommInitRank
self.NCCL_CHECK(
File "/sgl-workspace/sglang/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py", line 329, in NCCL_CHECK
raise RuntimeError(f"NCCL error: {error_str}")
RuntimeError: NCCL error: invalid usage (run with NCCL_DEBUG=WARN for details)

[2025-04-12 09:01:07 DP7 TP7] Scheduler hit an exception: Traceback (most recent call last):
File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 1787, in run_scheduler_process
scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 240, in init
self.tp_worker = TpWorkerClass(
File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker_overlap_thread.py", line 63, in init
self.worker = TpModelWorker(server_args, gpu_id, tp_rank, dp_rank, nccl_port)
File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker.py", line 68, in init
self.model_runner = ModelRunner(
File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 177, in init
min_per_gpu_memory = self.init_torch_distributed()
File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 252, in init_torch_distributed
initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
File "/sgl-workspace/sglang/python/sglang/srt/distributed/parallel_state.py", line 1055, in initialize_model_parallel
_TP = init_model_parallel_group(
File "/sgl-workspace/sglang/python/sglang/srt/distributed/parallel_state.py", line 890, in init_model_parallel_group
return GroupCoordinator(
File "/sgl-workspace/sglang/python/sglang/srt/distributed/parallel_state.py", line 233, in init
self.pynccl_comm = PyNcclCommunicator(
File "/sgl-workspace/sglang/python/sglang/srt/distributed/device_communicators/pynccl.py", line 108, in init
self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
File "/sgl-workspace/sglang/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py", line 350, in ncclCommInitRank
self.NCCL_CHECK(
File "/sgl-workspace/sglang/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py", line 329, in NCCL_CHECK
raise RuntimeError(f"NCCL error: {error_str}")
RuntimeError: NCCL error: invalid usage (run with NCCL_DEBUG=WARN for details)

[2025-04-12 09:01:07 DP2 TP2] Scheduler hit an exception: Traceback (most recent call last):
File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 1787, in run_scheduler_process
scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 240, in init
self.tp_worker = TpWorkerClass(
File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker_overlap_thread.py", line 63, in init
self.worker = TpModelWorker(server_args, gpu_id, tp_rank, dp_rank, nccl_port)
File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker.py", line 68, in init
self.model_runner = ModelRunner(
File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 177, in init
min_per_gpu_memory = self.init_torch_distributed()
File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 252, in init_torch_distributed
initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
File "/sgl-workspace/sglang/python/sglang/srt/distributed/parallel_state.py", line 1055, in initialize_model_parallel
_TP = init_model_parallel_group(
File "/sgl-workspace/sglang/python/sglang/srt/distributed/parallel_state.py", line 890, in init_model_parallel_group
return GroupCoordinator(
File "/sgl-workspace/sglang/python/sglang/srt/distributed/parallel_state.py", line 233, in init
self.pynccl_comm = PyNcclCommunicator(
File "/sgl-workspace/sglang/python/sglang/srt/distributed/device_communicators/pynccl.py", line 108, in init
self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
File "/sgl-workspace/sglang/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py", line 350, in ncclCommInitRank
self.NCCL_CHECK(
File "/sgl-workspace/sglang/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py", line 329, in NCCL_CHECK
raise RuntimeError(f"NCCL error: {error_str}")
RuntimeError: NCCL error: invalid usage (run with NCCL_DEBUG=WARN for details)

rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:559:559 [3] NCCL INFO cudaDriverVersion 12080
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:559:559 [3] NCCL INFO Bootstrap : Using eth0:10.8.10.18<0>
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:559:559 [3] NCCL INFO NET/Plugin: No plugin found (libnccl-net.so)
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:559:559 [3] NCCL INFO NET/Plugin: Plugin load returned 2 : /lib/x86_64-linux-gnu/libc.so.6: version `GLIBC_2.32' not found (required by /usr/local/nvidia/lib64/libnccl-net.so) : when loading libnccl-net.so
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:559:559 [3] NCCL INFO NET/Plugin: Using internal network plugin.
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:559:559 [3] NCCL INFO Failed to open libibverbs.so[.1]
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:559:559 [3] NCCL INFO NET/Socket : Using [0]eth0:10.8.10.18<0> [1]eth1:192.168.0.8<0> [2]eth2:192.168.1.8<0> [3]eth3:192.168.2.8<0> [4]eth4:192.168.3.8<0> [5]eth5:192.168.4.8<0> [6]eth6:192.168.5.8<0> [7]eth7:192.168.6.8<0> [8]eth8:192.168.7.8<0> [9]eth9:192.168.8.8<0>

rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:559:559 [3] net.cc:579 NCCL WARN Error: network gIB not found.
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:559:559 [3] NCCL INFO init.cc:321 -> 5
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:559:559 [3] NCCL INFO init.cc:1533 -> 5
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:559:559 [3] NCCL INFO init.cc:1799 -> 5
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:559:559 [3] NCCL INFO init.cc:1837 -> 5
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:560:560 [4] NCCL INFO cudaDriverVersion 12080
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:560:560 [4] NCCL INFO Bootstrap : Using eth0:10.8.10.18<0>
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:560:560 [4] NCCL INFO NET/Plugin: No plugin found (libnccl-net.so)
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:560:560 [4] NCCL INFO NET/Plugin: Plugin load returned 2 : /lib/x86_64-linux-gnu/libc.so.6: version `GLIBC_2.32' not found (required by /usr/local/nvidia/lib64/libnccl-net.so) : when loading libnccl-net.so
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:560:560 [4] NCCL INFO NET/Plugin: Using internal network plugin.
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:560:560 [4] NCCL INFO Failed to open libibverbs.so[.1]
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:560:560 [4] NCCL INFO NET/Socket : Using [0]eth0:10.8.10.18<0> [1]eth1:192.168.0.8<0> [2]eth2:192.168.1.8<0> [3]eth3:192.168.2.8<0> [4]eth4:192.168.3.8<0> [5]eth5:192.168.4.8<0> [6]eth6:192.168.5.8<0> [7]eth7:192.168.6.8<0> [8]eth8:192.168.7.8<0> [9]eth9:192.168.8.8<0>

rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:560:560 [4] net.cc:579 NCCL WARN Error: network gIB not found.
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:560:560 [4] NCCL INFO init.cc:321 -> 5
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:560:560 [4] NCCL INFO init.cc:1533 -> 5
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:560:560 [4] NCCL INFO init.cc:1799 -> 5
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:560:560 [4] NCCL INFO init.cc:1837 -> 5
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:556:556 [1] NCCL INFO cudaDriverVersion 12080
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:556:556 [1] NCCL INFO Bootstrap : Using eth0:10.8.10.18<0>
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:556:556 [1] NCCL INFO NET/Plugin: No plugin found (libnccl-net.so)
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:556:556 [1] NCCL INFO NET/Plugin: Plugin load returned 2 : /lib/x86_64-linux-gnu/libc.so.6: version `GLIBC_2.32' not found (required by /usr/local/nvidia/lib64/libnccl-net.so) : when loading libnccl-net.so
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:556:556 [1] NCCL INFO NET/Plugin: Using internal network plugin.
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:556:556 [1] NCCL INFO Failed to open libibverbs.so[.1]
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:556:556 [1] NCCL INFO NET/Socket : Using [0]eth0:10.8.10.18<0> [1]eth1:192.168.0.8<0> [2]eth2:192.168.1.8<0> [3]eth3:192.168.2.8<0> [4]eth4:192.168.3.8<0> [5]eth5:192.168.4.8<0> [6]eth6:192.168.5.8<0> [7]eth7:192.168.6.8<0> [8]eth8:192.168.7.8<0> [9]eth9:192.168.8.8<0>

rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:556:556 [1] net.cc:579 NCCL WARN Error: network gIB not found.
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:556:556 [1] NCCL INFO init.cc:321 -> 5
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:556:556 [1] NCCL INFO init.cc:1533 -> 5
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:556:556 [1] NCCL INFO init.cc:1799 -> 5
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:556:556 [1] NCCL INFO init.cc:1837 -> 5
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:563:563 [6] NCCL INFO cudaDriverVersion 12080
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:563:563 [6] NCCL INFO Bootstrap : Using eth0:10.8.10.18<0>
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:563:563 [6] NCCL INFO NET/Plugin: No plugin found (libnccl-net.so)
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:563:563 [6] NCCL INFO NET/Plugin: Plugin load returned 2 : /lib/x86_64-linux-gnu/libc.so.6: version `GLIBC_2.32' not found (required by /usr/local/nvidia/lib64/libnccl-net.so) : when loading libnccl-net.so
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:563:563 [6] NCCL INFO NET/Plugin: Using internal network plugin.
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:563:563 [6] NCCL INFO Failed to open libibverbs.so[.1]
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:563:563 [6] NCCL INFO NET/Socket : Using [0]eth0:10.8.10.18<0> [1]eth1:192.168.0.8<0> [2]eth2:192.168.1.8<0> [3]eth3:192.168.2.8<0> [4]eth4:192.168.3.8<0> [5]eth5:192.168.4.8<0> [6]eth6:192.168.5.8<0> [7]eth7:192.168.6.8<0> [8]eth8:192.168.7.8<0> [9]eth9:192.168.8.8<0>

rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:563:563 [6] net.cc:579 NCCL WARN Error: network gIB not found.
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:563:563 [6] NCCL INFO init.cc:321 -> 5
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:563:563 [6] NCCL INFO init.cc:1533 -> 5
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:563:563 [6] NCCL INFO init.cc:1799 -> 5
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:563:563 [6] NCCL INFO init.cc:1837 -> 5
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:562:562 [5] NCCL INFO cudaDriverVersion 12080
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:562:562 [5] NCCL INFO Bootstrap : Using eth0:10.8.10.18<0>
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:562:562 [5] NCCL INFO NET/Plugin: No plugin found (libnccl-net.so)
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:562:562 [5] NCCL INFO NET/Plugin: Plugin load returned 2 : /lib/x86_64-linux-gnu/libc.so.6: version `GLIBC_2.32' not found (required by /usr/local/nvidia/lib64/libnccl-net.so) : when loading libnccl-net.so
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:562:562 [5] NCCL INFO NET/Plugin: Using internal network plugin.
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:562:562 [5] NCCL INFO Failed to open libibverbs.so[.1]
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:562:562 [5] NCCL INFO NET/Socket : Using [0]eth0:10.8.10.18<0> [1]eth1:192.168.0.8<0> [2]eth2:192.168.1.8<0> [3]eth3:192.168.2.8<0> [4]eth4:192.168.3.8<0> [5]eth5:192.168.4.8<0> [6]eth6:192.168.5.8<0> [7]eth7:192.168.6.8<0> [8]eth8:192.168.7.8<0> [9]eth9:192.168.8.8<0>

rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:562:562 [5] net.cc:579 NCCL WARN Error: network gIB not found.
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:562:562 [5] NCCL INFO init.cc:321 -> 5
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:562:562 [5] NCCL INFO init.cc:1533 -> 5
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:562:562 [5] NCCL INFO init.cc:1799 -> 5
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:562:562 [5] NCCL INFO init.cc:1837 -> 5
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:565:565 [7] NCCL INFO cudaDriverVersion 12080
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:565:565 [7] NCCL INFO Bootstrap : Using eth0:10.8.10.18<0>
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:565:565 [7] NCCL INFO NET/Plugin: No plugin found (libnccl-net.so)
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:565:565 [7] NCCL INFO NET/Plugin: Plugin load returned 2 : /lib/x86_64-linux-gnu/libc.so.6: version `GLIBC_2.32' not found (required by /usr/local/nvidia/lib64/libnccl-net.so) : when loading libnccl-net.so
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:565:565 [7] NCCL INFO NET/Plugin: Using internal network plugin.
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:565:565 [7] NCCL INFO Failed to open libibverbs.so[.1]
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:565:565 [7] NCCL INFO NET/Socket : Using [0]eth0:10.8.10.18<0> [1]eth1:192.168.0.8<0> [2]eth2:192.168.1.8<0> [3]eth3:192.168.2.8<0> [4]eth4:192.168.3.8<0> [5]eth5:192.168.4.8<0> [6]eth6:192.168.5.8<0> [7]eth7:192.168.6.8<0> [8]eth8:192.168.7.8<0> [9]eth9:192.168.8.8<0>

rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:565:565 [7] net.cc:579 NCCL WARN Error: network gIB not found.
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:565:565 [7] NCCL INFO init.cc:321 -> 5
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:565:565 [7] NCCL INFO init.cc:1533 -> 5
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:565:565 [7] NCCL INFO init.cc:1799 -> 5
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:565:565 [7] NCCL INFO init.cc:1837 -> 5
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:558:558 [2] NCCL INFO cudaDriverVersion 12080
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:558:558 [2] NCCL INFO Bootstrap : Using eth0:10.8.10.18<0>
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:558:558 [2] NCCL INFO NET/Plugin: No plugin found (libnccl-net.so)
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:558:558 [2] NCCL INFO NET/Plugin: Plugin load returned 2 : /lib/x86_64-linux-gnu/libc.so.6: version `GLIBC_2.32' not found (required by /usr/local/nvidia/lib64/libnccl-net.so) : when loading libnccl-net.so
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:558:558 [2] NCCL INFO NET/Plugin: Using internal network plugin.
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:558:558 [2] NCCL INFO Failed to open libibverbs.so[.1]
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:558:558 [2] NCCL INFO NET/Socket : Using [0]eth0:10.8.10.18<0> [1]eth1:192.168.0.8<0> [2]eth2:192.168.1.8<0> [3]eth3:192.168.2.8<0> [4]eth4:192.168.3.8<0> [5]eth5:192.168.4.8<0> [6]eth6:192.168.5.8<0> [7]eth7:192.168.6.8<0> [8]eth8:192.168.7.8<0> [9]eth9:192.168.8.8<0>

rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:558:558 [2] net.cc:579 NCCL WARN Error: network gIB not found.
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:558:558 [2] NCCL INFO init.cc:321 -> 5
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:558:558 [2] NCCL INFO init.cc:1533 -> 5
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:558:558 [2] NCCL INFO init.cc:1799 -> 5
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:558:558 [2] NCCL INFO init.cc:1837 -> 5
[2025-04-12 09:01:07 DP0 TP0] Scheduler hit an exception: Traceback (most recent call last):
File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 1787, in run_scheduler_process
scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 240, in init
self.tp_worker = TpWorkerClass(
File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker_overlap_thread.py", line 63, in init
self.worker = TpModelWorker(server_args, gpu_id, tp_rank, dp_rank, nccl_port)
File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker.py", line 68, in init
self.model_runner = ModelRunner(
File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 177, in init
min_per_gpu_memory = self.init_torch_distributed()
File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 252, in init_torch_distributed
initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
File "/sgl-workspace/sglang/python/sglang/srt/distributed/parallel_state.py", line 1055, in initialize_model_parallel
_TP = init_model_parallel_group(
File "/sgl-workspace/sglang/python/sglang/srt/distributed/parallel_state.py", line 890, in init_model_parallel_group
return GroupCoordinator(
File "/sgl-workspace/sglang/python/sglang/srt/distributed/parallel_state.py", line 233, in init
self.pynccl_comm = PyNcclCommunicator(
File "/sgl-workspace/sglang/python/sglang/srt/distributed/device_communicators/pynccl.py", line 108, in init
self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
File "/sgl-workspace/sglang/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py", line 350, in ncclCommInitRank
self.NCCL_CHECK(
File "/sgl-workspace/sglang/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py", line 329, in NCCL_CHECK
raise RuntimeError(f"NCCL error: {error_str}")
RuntimeError: NCCL error: invalid usage (run with NCCL_DEBUG=WARN for details)

rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:554:554 [0] NCCL INFO Failed to open libibverbs.so[.1]
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:554:554 [0] NCCL INFO NET/Socket : Using [0]eth0:10.8.10.18<0> [1]eth1:192.168.0.8<0> [2]eth2:192.168.1.8<0> [3]eth3:192.168.2.8<0> [4]eth4:192.168.3.8<0> [5]eth5:192.168.4.8<0> [6]eth6:192.168.5.8<0> [7]eth7:192.168.6.8<0> [8]eth8:192.168.7.8<0> [9]eth9:192.168.8.8<0>

rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:554:554 [0] net.cc:579 NCCL WARN Error: network gIB not found.
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:554:554 [0] NCCL INFO init.cc:321 -> 5
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:554:554 [0] NCCL INFO init.cc:1533 -> 5
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:554:554 [0] NCCL INFO init.cc:1799 -> 5
rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp:554:554 [0] NCCL INFO init.cc:1837 -> 5
[rank0]:[W412 09:01:08.424402267 ProcessGroupNCCL.cpp:1250] Warning: WARNING: process group has NOT been destroyed before we destruct ProcessGroupNCCL. On normal program exit, the application should call destroy_process_group to ensure that any pending NCCL operations have finished in this process. In rare cases this process can exit before this point and block the progress of another member of the process group. This constraint has always been present, but this warning has only been added since PyTorch
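
For what it's worth, the plugin failure can be narrowed down from inside the serving container; a rough sketch of the checks I would run (pod name as above, plugin path taken from the NCCL warning, and assuming the recipe requests the gIB net through NCCL environment variables, which the "network gIB not found" warning suggests):

POD=rickruguichen-serving-deepseek-r1-model-serving-5d74c6754cq79gp
# confirm which NCCL settings actually reach the serving container
kubectl exec "$POD" -c serving -- sh -c 'env | grep -E "^NCCL|^LD_LIBRARY_PATH"'
# check why the net plugin refuses to load (the log above reports GLIBC_2.32 missing)
kubectl exec "$POD" -c serving -- sh -c 'ldd /usr/local/nvidia/lib64/libnccl-net.so | grep "not found"'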

@Abhishekbhagwat
Collaborator

Thanks @rick-c-goog, will send a fix soon for (1) and (2).
For (3), could I check if you used Cluster Toolkit to provision your cluster? (as per https://github.com/AI-Hypercomputer/gpu-recipes/blob/main/docs/configuring-environment-gke-a3-ultra.md)

@Abhishekbhagwat self-assigned this May 5, 2025