-
Notifications
You must be signed in to change notification settings - Fork 18
feat: detect CPUs and configure threading sensibly #291
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -1,3 +1,4 @@ | ||||||
import math | ||||||
import sys | ||||||
|
||||||
# When running this plugin on a Mac, we assume it's for local development | ||||||
|
@@ -83,6 +84,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: | |||||
f'vllm_spyre{".v1" if envs.VLLM_USE_V1 else ""}'\ | ||||||
'.worker.spyre_worker.SpyreWorker') | ||||||
|
||||||
cls._check_threading_config(parallel_config.world_size) | ||||||
|
||||||
if envs_spyre.VLLM_SPYRE_USE_CB and is_decoder: | ||||||
scheduler_config.scheduler_cls = "vllm_spyre.v1.core."\ | ||||||
"scheduler.ContinuousBatchingSpyreScheduler" | ||||||
|
@@ -309,3 +312,113 @@ def _get_matching_warmup_shapes( | |||||
if prompt_len <= shape['prompt_length'] | ||||||
and max_tokens <= shape['new_tokens'] | ||||||
] | ||||||
|
||||||
@classmethod | ||||||
def _check_threading_config(cls, worker_count: int): | ||||||
""" | ||||||
Check parallelism configuration to avoid CPU contention | ||||||
|
||||||
Libraries that support multi-threading (eg. OpenMP) default to | ||||||
parallelism based on the number of CPUs on the host. This can lead to | ||||||
CPU contention in containerized deployments especially when process | ||||||
forking is involved. This function provides better default behavior. | ||||||
""" | ||||||
|
||||||
# The quay.io/ibm-aiu/spyre-base image includes shell scripts that | ||||||
# automatically set OMP_NUM_THREADS to the result of `nproc --all`. | ||||||
# | ||||||
# vLLM also already has logic around threading to be aware of, | ||||||
# - sets TORCHINDUCTOR_COMPILE_THREADS=1 (https://github.com/vllm-project/vllm/blob/baba0389f7e810a361fff5229ce20c2d5a2b1fac/vllm/env_override.py#L38-L39) | ||||||
# - it will set OMP_NUM_THREADS=1 when using multiple workers (https://github.com/vllm-project/vllm/blob/baba0389f7e810a361fff5229ce20c2d5a2b1fac/vllm/executor/multiproc_worker_utils.py#L304) | ||||||
# - has configurations for OMP thread binding (https://github.com/vllm-project/vllm/blob/baba0389f7e810a361fff5229ce20c2d5a2b1fac/vllm/envs.py#L435-L438) | ||||||
# - the bind attempts to detect NUMA nodes (https://github.com/vllm-project/vllm/blob/baba0389f7e810a361fff5229ce20c2d5a2b1fac/vllm/v1/worker/cpu_worker.py#L111) | ||||||
|
||||||
# Always print current env for awareness | ||||||
threading_envs = [ | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thoughts on setting these as "constants" in this file? |
||||||
"OMP_NUM_THREADS", | ||||||
# "TORCHINDUCTOR_COMPILE_THREADS", # vLLM wants this set to 1 | ||||||
"DT_PARALLEL_THREADS", # affects the compilation during warmup | ||||||
# set these for good measure | ||||||
"OPENBLAS_NUM_THREADS", | ||||||
"MKL_NUM_THREADS", | ||||||
] | ||||||
env_map = {env: os.getenv(env) for env in threading_envs} | ||||||
logger.info( | ||||||
"Initial threading configurations: %s", | ||||||
' '.join([f"{env}={value}" for env, value in env_map.items()])) | ||||||
|
||||||
# Try to determine the CPU time/cores that we are allocated | ||||||
cpu_count: Optional[float] = None | ||||||
detection_message = "" | ||||||
try: | ||||||
# try to query cgroup CPU limits | ||||||
with open('/sys/fs/cgroup/cpu.max') as f: | ||||||
quota_str, period_str = f.read().strip().split() | ||||||
|
||||||
if quota_str != 'max': | ||||||
quota = int(quota_str) | ||||||
period = int(period_str) | ||||||
cpu_count = float(quota) / period | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In python we don't need to cast to float in order to do a floating point division
Suggested change
|
||||||
detection_message = f"Detected cgroup CPU limit of {cpu_count}" | ||||||
|
||||||
except FileNotFoundError: | ||||||
# file may not exist if not running under cgroups v2 | ||||||
pass | ||||||
except Exception as e: | ||||||
logger.debug( | ||||||
"Error parsing /sys/fs/cgroup/cpu.max to get CPU info", | ||||||
exc_info=e) | ||||||
|
||||||
# could try `nproc` here, but it is affected by | ||||||
# OMP_NUM_THREADS itself | ||||||
|
||||||
# try os.cpu_count() to get node CPU count | ||||||
if cpu_count is None and (cpu_count_res := os.cpu_count()) is not None: | ||||||
cpu_count = float(cpu_count_res) | ||||||
detection_message = \ | ||||||
f"Detected {cpu_count} CPUs from `os.cpu_count()`" | ||||||
|
||||||
cpus_per_worker = math.ceil( | ||||||
cpu_count / worker_count) if cpu_count is not None else None | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am not sure if |
||||||
|
||||||
thread_warning = "Excessive threads may result in CPU contention. " \ | ||||||
+ "Note that each worker processes has its own thread pools." \ | ||||||
if worker_count > 1 else "" | ||||||
failed_detection_message = "Unable to detect available CPUs to " \ | ||||||
"validate threading configuration." | ||||||
|
||||||
if envs_spyre.VLLM_SPYRE_UPDATE_THREAD_CONFIG: | ||||||
if cpus_per_worker is None: | ||||||
raise RuntimeError( | ||||||
f"{failed_detection_message} Use " | ||||||
"VLLM_SPYRE_UPDATE_THREAD_CONFIG=0 and configure manually." | ||||||
) | ||||||
Comment on lines
+392
to
+395
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Imo it would be better to have a warning message and keep the original threading values instead (ie. not changing anything), since you set There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's a good point since I made it on by default. I'll change it to just log instead of raising. |
||||||
|
||||||
for env in threading_envs: | ||||||
os.environ[env] = str(cpus_per_worker) | ||||||
|
||||||
logger.info( | ||||||
"%s for %d workers. Since VLLM_SPYRE_UPDATE_THREAD_CONFIG is " | ||||||
"enabled, setting threading configurations to %d", | ||||||
detection_message, worker_count, cpus_per_worker) | ||||||
return | ||||||
|
||||||
# In the case that VLLM_SPYRE_UPDATE_THREAD_CONFIG is not enabled, | ||||||
# check configs and maybe log a warning | ||||||
if cpus_per_worker is None: | ||||||
logger.info("%s %s", failed_detection_message, thread_warning) | ||||||
return | ||||||
|
||||||
def _float_or_0(s: str) -> float: | ||||||
try: | ||||||
return float(s) | ||||||
except ValueError: | ||||||
return 0.0 | ||||||
|
||||||
if any((value is None or _float_or_0(value) > 1.2 * cpus_per_worker) | ||||||
for value in env_map.values()): | ||||||
logger.warning( | ||||||
"%s %s for %d workers. Recommend setting each threading " | ||||||
"configuration to %d. Set VLLM_SPYRE_UPDATE_THREAD_CONFIG=1 " | ||||||
"to do this automatically.", thread_warning, detection_message, | ||||||
worker_count, cpus_per_worker) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks for the pointers