|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
| 3 | +import os |
3 | 4 | import sys
|
| 5 | +import warnings |
4 | 6 | from dataclasses import dataclass, field
|
5 |
| -from functools import cache, partial |
| 7 | +from functools import cache, partial, wraps |
6 | 8 | from importlib.util import find_spec
|
7 | 9 | from pathlib import Path
|
8 |
| -from typing import TYPE_CHECKING |
| 10 | +from typing import TYPE_CHECKING, Literal, ParamSpec, TypeVar, cast, overload |
9 | 11 |
|
10 | 12 | from packaging.version import Version
|
11 | 13 |
|
12 | 14 | if TYPE_CHECKING:
|
| 15 | + from collections.abc import Callable |
13 | 16 | from importlib.metadata import PackageMetadata
|
14 | 17 |
|
| 18 | +P = ParamSpec("P") |
| 19 | +R = TypeVar("R") |
| 20 | + |
15 | 21 |
|
16 | 22 | if TYPE_CHECKING:
|
17 | 23 | # type checkers are confused and can only see …core.Array
|
@@ -90,3 +96,101 @@ def pkg_version(package: str) -> Version:
|
90 | 96 | # but this code makes it possible to run scanpy without it.
|
91 | 97 | def old_positionals(*old_positionals: str):
|
92 | 98 | return lambda func: func
|
| 99 | + |
| 100 | + |
| 101 | +@overload |
| 102 | +def njit(fn: Callable[P, R], /) -> Callable[P, R]: ... |
| 103 | +@overload |
| 104 | +def njit() -> Callable[[Callable[P, R]], Callable[P, R]]: ... |
| 105 | +def njit( |
| 106 | + fn: Callable[P, R] | None = None, / |
| 107 | +) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]: |
| 108 | + """\ |
| 109 | + Jit-compile a function using numba. |
| 110 | +
|
| 111 | + On call, this function dispatches to a parallel or sequential numba function, |
| 112 | + depending on if it has been called from a thread pool. |
| 113 | +
|
| 114 | + See <https://github.com/numbagg/numbagg/pull/201/files#r1409374809> |
| 115 | + """ |
| 116 | + |
| 117 | + def decorator(f: Callable[P, R], /) -> Callable[P, R]: |
| 118 | + import numba |
| 119 | + |
| 120 | + fns: dict[bool, Callable[P, R]] = { |
| 121 | + parallel: numba.njit(f, cache=True, parallel=parallel) # noqa: TID251 |
| 122 | + for parallel in (True, False) |
| 123 | + } |
| 124 | + |
| 125 | + @wraps(f) |
| 126 | + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: |
| 127 | + parallel = not _is_in_unsafe_thread_pool() |
| 128 | + if not parallel: |
| 129 | + msg = ( |
| 130 | + "Detected unsupported threading environment. " |
| 131 | + f"Trying to run {f.__name__} in serial mode. " |
| 132 | + "In case of problems, install `tbb`." |
| 133 | + ) |
| 134 | + warnings.warn(msg, stacklevel=2) |
| 135 | + return fns[parallel](*args, **kwargs) |
| 136 | + |
| 137 | + return wrapper |
| 138 | + |
| 139 | + return decorator if fn is None else decorator(fn) |
| 140 | + |
| 141 | + |
| 142 | +LayerType = Literal["default", "safe", "threadsafe", "forksafe"] |
| 143 | +Layer = Literal["tbb", "omp", "workqueue"] |
| 144 | + |
| 145 | + |
| 146 | +LAYERS: dict[LayerType, set[Layer]] = { |
| 147 | + "default": {"tbb", "omp", "workqueue"}, |
| 148 | + "safe": {"tbb"}, |
| 149 | + "threadsafe": {"tbb", "omp"}, |
| 150 | + "forksafe": {"tbb", "workqueue", *(() if sys.platform == "linux" else {"omp"})}, |
| 151 | +} |
| 152 | + |
| 153 | + |
| 154 | +def _is_in_unsafe_thread_pool() -> bool: |
| 155 | + import threading |
| 156 | + |
| 157 | + current_thread = threading.current_thread() |
| 158 | + # ThreadPoolExecutor threads typically have names like 'ThreadPoolExecutor-0_1' |
| 159 | + return ( |
| 160 | + current_thread.name.startswith("ThreadPoolExecutor") |
| 161 | + and _numba_threading_layer() not in LAYERS["threadsafe"] |
| 162 | + ) |
| 163 | + |
| 164 | + |
| 165 | +@cache |
| 166 | +def _numba_threading_layer() -> Layer: |
| 167 | + """\ |
| 168 | + Get numba’s threading layer. |
| 169 | +
|
| 170 | + This function implements the algorithm as described in |
| 171 | + <https://numba.readthedocs.io/en/stable/user/threading-layer.html> |
| 172 | + """ |
| 173 | + import importlib |
| 174 | + |
| 175 | + import numba |
| 176 | + |
| 177 | + if (available := LAYERS.get(numba.config.THREADING_LAYER)) is None: |
| 178 | + # given by direct name |
| 179 | + return numba.config.THREADING_LAYER |
| 180 | + |
| 181 | + # given by layer type (safe, …) |
| 182 | + for layer in cast(list[Layer], numba.config.THREADING_LAYER_PRIORITY): |
| 183 | + if layer not in available: |
| 184 | + continue |
| 185 | + if layer != "workqueue": |
| 186 | + try: # `importlib.util.find_spec` doesn’t work here |
| 187 | + importlib.import_module(f"numba.np.ufunc.{layer}pool") |
| 188 | + except ImportError: |
| 189 | + continue |
| 190 | + # the layer has been found |
| 191 | + return layer |
| 192 | + msg = ( |
| 193 | + f"No loadable threading layer: {numba.config.THREADING_LAYER=} " |
| 194 | + f" ({available=}, {numba.config.THREADING_LAYER_PRIORITY=})" |
| 195 | + ) |
| 196 | + raise ValueError(msg) |
0 commit comments