Skip to content

Commit d0adc25

Browse files
authored
move all njit calls into a decorator (#3335)
1 parent 9d3c340 commit d0adc25

File tree

12 files changed

+150
-37
lines changed

12 files changed

+150
-37
lines changed

docs/release-notes/3335.feature.md

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Run numba functions single-threaded when called from a `ThreadPoolExecutor` thread while numba’s threading layer is not thread-safe {smaller}`P Angerer`

pyproject.toml

+2
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,8 @@ required-imports = ["from __future__ import annotations"]
262262
"pandas.value_counts".msg = "Use pd.Series(a).value_counts() instead"
263263
"legacy_api_wrap.legacy_api".msg = "Use scanpy._compat.old_positionals instead"
264264
"numpy.bool".msg = "Use `np.bool_` instead for numpy>=1.24<2 compatibility"
265+
"numba.jit".msg = "Use `scanpy._compat.njit` instead"
266+
"numba.njit".msg = "Use `scanpy._compat.njit` instead"
265267
[tool.ruff.lint.flake8-type-checking]
266268
exempt-modules = []
267269
strict = true

src/scanpy/_compat.py

+106-2
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,23 @@
11
from __future__ import annotations
22

3+
import os
34
import sys
5+
import warnings
46
from dataclasses import dataclass, field
5-
from functools import cache, partial
7+
from functools import cache, partial, wraps
68
from importlib.util import find_spec
79
from pathlib import Path
8-
from typing import TYPE_CHECKING
10+
from typing import TYPE_CHECKING, Literal, ParamSpec, TypeVar, cast, overload
911

1012
from packaging.version import Version
1113

1214
if TYPE_CHECKING:
15+
from collections.abc import Callable
1316
from importlib.metadata import PackageMetadata
1417

18+
P = ParamSpec("P")
19+
R = TypeVar("R")
20+
1521

1622
if TYPE_CHECKING:
1723
# type checkers are confused and can only see …core.Array
@@ -90,3 +96,101 @@ def pkg_version(package: str) -> Version:
9096
# but this code makes it possible to run scanpy without it.
9197
def old_positionals(*old_positionals: str):
9298
return lambda func: func
99+
100+
101+
@overload
def njit(fn: Callable[P, R], /) -> Callable[P, R]: ...
@overload
def njit() -> Callable[[Callable[P, R]], Callable[P, R]]: ...
def njit(
    fn: Callable[P, R] | None = None, /
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
    """\
    Jit-compile a function using numba.

    Usable both bare (``@njit``) and called (``@njit()``).

    Two numba dispatchers are created for the decorated function,
    one with ``parallel=True`` and one with ``parallel=False`` (both with ``cache=True``).
    On every call, the returned wrapper picks the serial one (and emits a warning)
    if `_is_in_unsafe_thread_pool` reports an unsafe thread pool,
    and the parallel one otherwise.

    See <https://github.com/numbagg/numbagg/pull/201/files#r1409374809>
    """

    def decorator(f: Callable[P, R], /) -> Callable[P, R]:
        # imported here so that numba is only loaded once something gets decorated
        import numba

        # same creation order as before: parallel variant first, serial second
        run_parallel: Callable[P, R] = numba.njit(f, cache=True, parallel=True)  # noqa: TID251
        run_serial: Callable[P, R] = numba.njit(f, cache=True, parallel=False)  # noqa: TID251

        @wraps(f)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            if _is_in_unsafe_thread_pool():
                # stacklevel=2 attributes the warning to the caller of the wrapper
                warnings.warn(
                    "Detected unsupported threading environment. "
                    f"Trying to run {f.__name__} in serial mode. "
                    "In case of problems, install `tbb`.",
                    stacklevel=2,
                )
                return run_serial(*args, **kwargs)
            return run_parallel(*args, **kwargs)

        return wrapper

    if fn is None:
        # called as `@njit()`: hand back the decorator itself
        return decorator
    # called as `@njit`: decorate right away
    return decorator(fn)
140+
141+
142+
LayerType = Literal["default", "safe", "threadsafe", "forksafe"]
143+
Layer = Literal["tbb", "omp", "workqueue"]
144+
145+
146+
LAYERS: dict[LayerType, set[Layer]] = {
147+
"default": {"tbb", "omp", "workqueue"},
148+
"safe": {"tbb"},
149+
"threadsafe": {"tbb", "omp"},
150+
"forksafe": {"tbb", "workqueue", *(() if sys.platform == "linux" else {"omp"})},
151+
}
152+
153+
154+
def _is_in_unsafe_thread_pool() -> bool:
155+
import threading
156+
157+
current_thread = threading.current_thread()
158+
# ThreadPoolExecutor threads typically have names like 'ThreadPoolExecutor-0_1'
159+
return (
160+
current_thread.name.startswith("ThreadPoolExecutor")
161+
and _numba_threading_layer() not in LAYERS["threadsafe"]
162+
)
163+
164+
165+
@cache
def _numba_threading_layer() -> Layer:
    """\
    Get numba’s threading layer.

    This function implements the algorithm as described in
    <https://numba.readthedocs.io/en/stable/user/threading-layer.html>

    The result is memoized by `functools.cache`,
    so the configuration is only inspected on the first call.

    Returns
    -------
    ``numba.config.THREADING_LAYER`` unchanged if it is not a key of ``LAYERS``,
    otherwise the first entry of ``numba.config.THREADING_LAYER_PRIORITY``
    that belongs to the selected group and whose pool module can be imported
    (``workqueue`` is accepted without an import attempt).

    Raises
    ------
    ValueError
        If no entry of the priority list qualifies.
    """
    import importlib

    import numba

    if (available := LAYERS.get(numba.config.THREADING_LAYER)) is None:
        # given by direct name
        return numba.config.THREADING_LAYER

    # given by layer type (safe, …)
    for layer in cast(list[Layer], numba.config.THREADING_LAYER_PRIORITY):
        if layer not in available:
            continue
        if layer != "workqueue":
            try:  # `importlib.util.find_spec` doesn’t work here
                importlib.import_module(f"numba.np.ufunc.{layer}pool")
            except ImportError:
                continue
        # the layer has been found
        return layer
    # exactly one space between the two parts (the message used to contain two)
    msg = (
        f"No loadable threading layer: {numba.config.THREADING_LAYER=} "
        f"({available=}, {numba.config.THREADING_LAYER_PRIORITY=})"
    )
    raise ValueError(msg)

src/scanpy/_utils/compute/is_constant.py

+13-10
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55
from numbers import Integral
66
from typing import TYPE_CHECKING, TypeVar, overload
77

8+
import numba
89
import numpy as np
9-
from numba import njit
1010
from scipy import sparse
1111

12-
from ..._compat import DaskArray
12+
from ..._compat import DaskArray, njit
1313

1414
if TYPE_CHECKING:
1515
from typing import Literal
@@ -103,22 +103,21 @@ def _(
103103
else:
104104
return (a.data == 0).all()
105105
if axis == 1:
106-
return _is_constant_csr_rows(a.data, a.indices, a.indptr, a.shape)
106+
return _is_constant_csr_rows(a.data, a.indptr, a.shape)
107107
elif axis == 0:
108108
a = a.T.tocsr()
109-
return _is_constant_csr_rows(a.data, a.indices, a.indptr, a.shape)
109+
return _is_constant_csr_rows(a.data, a.indptr, a.shape)
110110

111111

112112
@njit
113113
def _is_constant_csr_rows(
114114
data: NDArray[np.number],
115-
indices: NDArray[np.integer],
116115
indptr: NDArray[np.integer],
117116
shape: tuple[int, int],
118-
):
117+
) -> NDArray[np.bool_]:
119118
n = len(indptr) - 1
120119
result = np.ones(n, dtype=np.bool_)
121-
for i in range(n):
120+
for i in numba.prange(n):
122121
start = indptr[i]
123122
stop = indptr[i + 1]
124123
val = data[start] if stop - start == shape[1] else 0
@@ -139,10 +138,10 @@ def _(
139138
else:
140139
return (a.data == 0).all()
141140
if axis == 0:
142-
return _is_constant_csr_rows(a.data, a.indices, a.indptr, a.shape[::-1])
141+
return _is_constant_csr_rows(a.data, a.indptr, a.shape[::-1])
143142
elif axis == 1:
144143
a = a.T.tocsc()
145-
return _is_constant_csr_rows(a.data, a.indices, a.indptr, a.shape[::-1])
144+
return _is_constant_csr_rows(a.data, a.indptr, a.shape[::-1])
146145

147146

148147
@is_constant.register(DaskArray)
@@ -151,4 +150,8 @@ def _(a: DaskArray, axis: Literal[0, 1] | None = None) -> bool | NDArray[np.bool
151150
v = a[tuple(0 for _ in range(a.ndim))].compute()
152151
return (a == v).all()
153152
# TODO: use overlapping blocks and reduction instead of `drop_axis`
154-
return a.map_blocks(partial(is_constant, axis=axis), drop_axis=axis)
153+
return a.map_blocks(
154+
partial(is_constant, axis=axis),
155+
drop_axis=axis,
156+
meta=np.array([], dtype=a.dtype),
157+
)

src/scanpy/experimental/pp/_highly_variable_genes.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from anndata import AnnData
1313

1414
from scanpy import logging as logg
15+
from scanpy._compat import njit
1516
from scanpy._settings import Verbosity, settings
1617
from scanpy._utils import _doc_params, check_nonnegative_integers, view_to_actual
1718
from scanpy.experimental._docs import (
@@ -32,7 +33,7 @@
3233
from numpy.typing import NDArray
3334

3435

35-
@nb.njit(parallel=True)
36+
@njit
3637
def _calculate_res_sparse(
3738
indptr: NDArray[np.integer],
3839
index: NDArray[np.integer],
@@ -92,7 +93,7 @@ def clac_clipped_res_sparse(gene: int, cell: int, value: np.float64) -> np.float
9293
return residuals
9394

9495

95-
@nb.njit(parallel=True)
96+
@njit
9697
def _calculate_res_dense(
9798
matrix,
9899
*,

src/scanpy/metrics/_gearys_c.py

+6-7
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import numpy as np
1010
from scipy import sparse
1111

12-
from .._compat import fullname
12+
from .._compat import fullname, njit
1313
from ..get import _get_obs_rep
1414
from ._common import _check_vals, _resolve_vals
1515

@@ -136,7 +136,6 @@ def gearys_c(
136136
# tests to fail.
137137

138138

139-
@numba.njit(cache=True, parallel=True)
140139
def _gearys_c_vec(
141140
data: np.ndarray,
142141
indices: np.ndarray,
@@ -147,7 +146,7 @@ def _gearys_c_vec(
147146
return _gearys_c_vec_W(data, indices, indptr, x, W)
148147

149148

150-
@numba.njit(cache=True, parallel=True)
149+
@njit
151150
def _gearys_c_vec_W(
152151
data: np.ndarray,
153152
indices: np.ndarray,
@@ -182,7 +181,7 @@ def _gearys_c_vec_W(
182181
# https://github.com/numba/numba/issues/6774#issuecomment-788789663
183182

184183

185-
@numba.njit(cache=True)
184+
@numba.njit(cache=True, parallel=False) # noqa: TID251
186185
def _gearys_c_inner_sparse_x_densevec(
187186
g_data: np.ndarray,
188187
g_indices: np.ndarray,
@@ -203,7 +202,7 @@ def _gearys_c_inner_sparse_x_densevec(
203202
return numer / denom
204203

205204

206-
@numba.njit(cache=True)
205+
@numba.njit(cache=True, parallel=False) # noqa: TID251
207206
def _gearys_c_inner_sparse_x_sparsevec( # noqa: PLR0917
208207
g_data: np.ndarray,
209208
g_indices: np.ndarray,
@@ -239,7 +238,7 @@ def _gearys_c_inner_sparse_x_sparsevec( # noqa: PLR0917
239238
return numer / denom
240239

241240

242-
@numba.njit(cache=True, parallel=True)
241+
@njit
243242
def _gearys_c_mtx(
244243
g_data: np.ndarray,
245244
g_indices: np.ndarray,
@@ -256,7 +255,7 @@ def _gearys_c_mtx(
256255
return out
257256

258257

259-
@numba.njit(cache=True, parallel=True)
258+
@njit
260259
def _gearys_c_mtx_csr( # noqa: PLR0917
261260
g_data: np.ndarray,
262261
g_indices: np.ndarray,

src/scanpy/metrics/_morans_i.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import numpy as np
1010
from scipy import sparse
1111

12-
from .._compat import fullname
12+
from .._compat import fullname, njit
1313
from ..get import _get_obs_rep
1414
from ._common import _check_vals, _resolve_vals
1515

@@ -126,7 +126,7 @@ def morans_i(
126126
# This is done in a very similar way to gearys_c. See notes there for details.
127127

128128

129-
@numba.njit(cache=True, parallel=True)
129+
@njit
130130
def _morans_i_vec(
131131
g_data: np.ndarray,
132132
g_indices: np.ndarray,
@@ -137,7 +137,7 @@ def _morans_i_vec(
137137
return _morans_i_vec_W(g_data, g_indices, g_indptr, x, W)
138138

139139

140-
@numba.njit(cache=True)
140+
@numba.njit(cache=True, parallel=False) # noqa: TID251
141141
def _morans_i_vec_W(
142142
g_data: np.ndarray,
143143
g_indices: np.ndarray,
@@ -159,7 +159,7 @@ def _morans_i_vec_W(
159159
return len(x) / W * inum / z2ss
160160

161161

162-
@numba.njit(cache=True)
162+
@numba.njit(cache=True, parallel=False) # noqa: TID251
163163
def _morans_i_vec_W_sparse( # noqa: PLR0917
164164
g_data: np.ndarray,
165165
g_indices: np.ndarray,
@@ -174,7 +174,7 @@ def _morans_i_vec_W_sparse( # noqa: PLR0917
174174
return _morans_i_vec_W(g_data, g_indices, g_indptr, x, W)
175175

176176

177-
@numba.njit(cache=True, parallel=True)
177+
@njit
178178
def _morans_i_mtx(
179179
g_data: np.ndarray,
180180
g_indices: np.ndarray,
@@ -191,7 +191,7 @@ def _morans_i_mtx(
191191
return out
192192

193193

194-
@numba.njit(cache=True, parallel=True)
194+
@njit
195195
def _morans_i_mtx_csr( # noqa: PLR0917
196196
g_data: np.ndarray,
197197
g_indices: np.ndarray,

src/scanpy/preprocessing/_highly_variable_genes.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,8 @@ def _highly_variable_genes_seurat_v3(
200200
return df
201201

202202

203-
@numba.njit(cache=True)
203+
# parallel=False needed for accuracy
204+
@numba.njit(cache=True, parallel=False) # noqa: TID251
204205
def _sum_and_sum_squares_clipped(
205206
indices: NDArray[np.integer],
206207
data: NDArray[np.floating],
@@ -211,7 +212,7 @@ def _sum_and_sum_squares_clipped(
211212
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
212213
squared_batch_counts_sum = np.zeros(n_cols, dtype=np.float64)
213214
batch_counts_sum = np.zeros(n_cols, dtype=np.float64)
214-
for i in range(nnz):
215+
for i in numba.prange(nnz):
215216
idx = indices[i]
216217
element = min(np.float64(data[i]), clip_val[idx])
217218
squared_batch_counts_sum[idx] += element**2

src/scanpy/preprocessing/_qc.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from scanpy.preprocessing._distributed import materialize_as_ndarray
1313
from scanpy.preprocessing._utils import _get_mean_var
1414

15-
from .._compat import DaskArray
15+
from .._compat import DaskArray, njit
1616
from .._utils import _doc_params, axis_nnz, axis_sum
1717
from ._docs import (
1818
doc_adata_basic,
@@ -445,7 +445,7 @@ def _(mtx: spmatrix, ns: Collection[int]) -> DaskArray:
445445
return top_segment_proportions_sparse_csr(mtx.data, mtx.indptr, np.array(ns))
446446

447447

448-
@numba.njit(cache=True, parallel=True)
448+
@njit
449449
def top_segment_proportions_sparse_csr(data, indptr, ns):
450450
# work around https://github.com/numba/numba/issues/5056
451451
indptr = indptr.astype(np.int64)

0 commit comments

Comments
 (0)