Skip to content

feat: detect CPUs and configure threading sensibly #291

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions vllm_spyre/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
# many cases, so it should only be enabled when prompt_logprobs are required
# for experimentation purposes.
VLLM_SPYRE_ENABLE_PROMPT_LOGPROBS: bool = False
VLLM_SPYRE_UPDATE_THREAD_CONFIG: bool = True

logger = init_logger(__name__)

Expand Down Expand Up @@ -94,6 +95,11 @@ def _backend_backwards_compat() -> str:
# By default, prompt_logprobs aren't supported
"VLLM_SPYRE_ENABLE_PROMPT_LOGPROBS":
lambda: bool(int(os.getenv("VLLM_SPYRE_ENABLE_PROMPT_LOGPROBS", "0"))),

# Allow vllm-spyre to update env vars related to multi-threading (e.g. OMP)
# based on the detected CPU cores and server configuration
"VLLM_SPYRE_UPDATE_THREAD_CONFIG":
lambda: bool(int(os.getenv("VLLM_SPYRE_UPDATE_THREAD_CONFIG", "1"))),
}
# --8<-- [end:env-vars-definition]

Expand Down
113 changes: 113 additions & 0 deletions vllm_spyre/platform.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import math
import sys

# When running this plugin on a Mac, we assume it's for local development
Expand Down Expand Up @@ -83,6 +84,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
f'vllm_spyre{".v1" if envs.VLLM_USE_V1 else ""}'\
'.worker.spyre_worker.SpyreWorker')

cls._check_threading_config(parallel_config.world_size)

if envs_spyre.VLLM_SPYRE_USE_CB and is_decoder:
scheduler_config.scheduler_cls = "vllm_spyre.v1.core."\
"scheduler.ContinuousBatchingSpyreScheduler"
Expand Down Expand Up @@ -309,3 +312,113 @@ def _get_matching_warmup_shapes(
if prompt_len <= shape['prompt_length']
and max_tokens <= shape['new_tokens']
]

@classmethod
def _check_threading_config(cls, worker_count: int):
"""
Check parallelism configuration to avoid CPU contention

Libraries that support multi-threading (e.g. OpenMP) default to
parallelism based on the number of CPUs on the host. This can lead to
CPU contention in containerized deployments, especially when process
forking is involved. This function provides better default behavior.
"""

# The quay.io/ibm-aiu/spyre-base image includes shell scripts that
# automatically set OMP_NUM_THREADS to the result of `nproc --all`.
#
# vLLM also already has logic around threading to be aware of,
# - sets TORCHINDUCTOR_COMPILE_THREADS=1 (https://github.com/vllm-project/vllm/blob/baba0389f7e810a361fff5229ce20c2d5a2b1fac/vllm/env_override.py#L38-L39)
# - it will set OMP_NUM_THREADS=1 when using multiple workers (https://github.com/vllm-project/vllm/blob/baba0389f7e810a361fff5229ce20c2d5a2b1fac/vllm/executor/multiproc_worker_utils.py#L304)
# - has configurations for OMP thread binding (https://github.com/vllm-project/vllm/blob/baba0389f7e810a361fff5229ce20c2d5a2b1fac/vllm/envs.py#L435-L438)
# - the bind attempts to detect NUMA nodes (https://github.com/vllm-project/vllm/blob/baba0389f7e810a361fff5229ce20c2d5a2b1fac/vllm/v1/worker/cpu_worker.py#L111)
Comment on lines +331 to +334
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for the pointers


# Always print current env for awareness
threading_envs = [
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thoughts on setting these as "constants" in this file?

"OMP_NUM_THREADS",
# "TORCHINDUCTOR_COMPILE_THREADS", # vLLM wants this set to 1
"DT_PARALLEL_THREADS", # affects the compilation during warmup
# set these for good measure
"OPENBLAS_NUM_THREADS",
"MKL_NUM_THREADS",
]
env_map = {env: os.getenv(env) for env in threading_envs}
logger.info(
"Initial threading configurations: %s",
' '.join([f"{env}={value}" for env, value in env_map.items()]))

# Try to determine the CPU time/cores that we are allocated
cpu_count: Optional[float] = None
detection_message = ""
try:
# try to query cgroup CPU limits
with open('/sys/fs/cgroup/cpu.max') as f:
quota_str, period_str = f.read().strip().split()

if quota_str != 'max':
quota = int(quota_str)
period = int(period_str)
cpu_count = float(quota) / period
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In Python, we don't need to cast to float in order to do floating-point division.

Suggested change
cpu_count = float(quota) / period
cpu_count = quota / period

detection_message = f"Detected cgroup CPU limit of {cpu_count}"

except FileNotFoundError:
# file may not exist if not running under cgroups v2
pass
except Exception as e:
logger.debug(
"Error parsing /sys/fs/cgroup/cpu.max to get CPU info",
exc_info=e)

# could try `nproc` here, but it is affected by
# OMP_NUM_THREADS itself

# try os.cpu_count() to get node CPU count
if cpu_count is None and (cpu_count_res := os.cpu_count()) is not None:
cpu_count = float(cpu_count_res)
detection_message = \
f"Detected {cpu_count} CPUs from `os.cpu_count()`"

cpus_per_worker = math.ceil(
cpu_count / worker_count) if cpu_count is not None else None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure if worker_count will always contain a valid value. Can you add an assert that worker_count is > 0 at the beginning of the function?


thread_warning = "Excessive threads may result in CPU contention. " \
+ "Note that each worker processes has its own thread pools." \
if worker_count > 1 else ""
failed_detection_message = "Unable to detect available CPUs to " \
"validate threading configuration."

if envs_spyre.VLLM_SPYRE_UPDATE_THREAD_CONFIG:
if cpus_per_worker is None:
raise RuntimeError(
f"{failed_detection_message} Use "
"VLLM_SPYRE_UPDATE_THREAD_CONFIG=0 and configure manually."
)
Comment on lines +392 to +395
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Imo it would be better to have a warning message and keep the original threading values instead (ie. not changing anything), since you set VLLM_SPYRE_UPDATE_THREAD_CONFIG to 1 by default and there is the possibility that both the parsing of cpu.max fails and os.cpu_count() returns None

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good point since I made it on by default. I'll change it to just log instead of raising.


for env in threading_envs:
os.environ[env] = str(cpus_per_worker)

logger.info(
"%s for %d workers. Since VLLM_SPYRE_UPDATE_THREAD_CONFIG is "
"enabled, setting threading configurations to %d",
detection_message, worker_count, cpus_per_worker)
return

# In the case that VLLM_SPYRE_UPDATE_THREAD_CONFIG is not enabled,
# check configs and maybe log a warning
if cpus_per_worker is None:
logger.info("%s %s", failed_detection_message, thread_warning)
return

def _float_or_0(s: str) -> float:
try:
return float(s)
except ValueError:
return 0.0

if any((value is None or _float_or_0(value) > 1.2 * cpus_per_worker)
for value in env_map.values()):
logger.warning(
"%s %s for %d workers. Recommend setting each threading "
"configuration to %d. Set VLLM_SPYRE_UPDATE_THREAD_CONFIG=1 "
"to do this automatically.", thread_warning, detection_message,
worker_count, cpus_per_worker)