move all njit calls into a decorator #3335

Merged: 17 commits merged into main from safe-njit on Nov 8, 2024

Conversation

flying-sheep (Member) commented Oct 31, 2024

@ilan-gold said:

> I found numba/numba#9288, which would suggest that the issue really is threading (they claim it comes from guvectorize, but it would seem we are seeing the same thing). It seems like there are a few things for us to try from there, especially if you really saw that using dask in processing mode is not the fix (strange, but ok). The issue appears to be mac-specific as well, so a bit less critical.
> See: numbagg/numbagg#201
I can try some of these out
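
For context, a minimal sketch of the decorator pattern the PR title describes (illustrative only, not scanpy's actual implementation; the main-thread check below is a hypothetical stand-in for the real threading-layer inspection):

import threading
from functools import wraps

import numba


def njit(func=None, /, **kwargs):
    """Project-level njit wrapper: compile a parallel and a serial variant
    and dispatch to the serial one when called from a worker thread."""

    def decorator(f):
        fn_parallel = numba.njit(cache=True, parallel=True, **kwargs)(f)
        fn_serial = numba.njit(cache=True, parallel=False, **kwargs)(f)

        @wraps(f)
        def wrapper(*args, **kw):
            # hypothetical check; the real change inspects numba's threading layer instead
            if threading.current_thread() is threading.main_thread():
                return fn_parallel(*args, **kw)
            return fn_serial(*args, **kw)

        return wrapper

    return decorator if func is None else decorator(func)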

codecov bot commented Oct 31, 2024

Codecov Report

Attention: Patch coverage is 90.54054% with 7 lines in your changes missing coverage. Please review.

Project coverage is 76.59%. Comparing base (9d3c340) to head (e8c5d14).
Report is 43 commits behind head on main.

Files with missing lines                 | Patch % | Lines
src/scanpy/_compat.py                    | 87.23%  | 6 Missing ⚠️
src/scanpy/_utils/compute/is_constant.py | 87.50%  | 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #3335      +/-   ##
==========================================
- Coverage   77.26%   76.59%   -0.68%     
==========================================
  Files         111      111              
  Lines       12630    12827     +197     
==========================================
+ Hits         9759     9825      +66     
- Misses       2871     3002     +131     
Files with missing lines                              | Coverage         | Δ
...c/scanpy/experimental/pp/_highly_variable_genes.py | 63.69% <100.00%> | (+0.23%) ⬆️
src/scanpy/metrics/_gearys_c.py                       | 57.50% <100.00%> | (-28.55%) ⬇️
src/scanpy/metrics/_morans_i.py                       | 64.17% <100.00%> | (-21.87%) ⬇️
src/scanpy/preprocessing/_highly_variable_genes.py    | 95.23% <ø>       | (ø)
src/scanpy/preprocessing/_qc.py                       | 82.22% <100.00%> | (-11.98%) ⬇️
src/scanpy/preprocessing/_scale.py                    | 79.10% <100.00%> | (-12.79%) ⬇️
src/scanpy/preprocessing/_simple.py                   | 90.56% <ø>       | (ø)
src/scanpy/preprocessing/_utils.py                    | 53.08% <100.00%> | (-44.35%) ⬇️
src/scanpy/_utils/compute/is_constant.py              | 80.00% <87.50%>  | (ø)
src/scanpy/_compat.py                                 | 81.72% <87.23%>  | (+4.63%) ⬆️

scverse-benchmark bot commented Oct 31, 2024

Benchmark changes

Change | Before [9d3c340] | After [e8c5d14] | Ratio | Benchmark (Parameter)
-      | 404M             | 367M            | 0.91  | preprocessing_counts.peakmem_scrublet('pbmc68k_reduced', 'counts-off-axis')
+      | 2.41±0.1ms       | 3.49±0.1ms      | 1.45  | preprocessing_log.FastSuite.time_mean_var('pbmc3k', None)
+      | 321M             | 367M            | 1.14  | preprocessing_log.peakmem_pca('pbmc68k_reduced', 'off-axis')
+      | 22.2±0.2ms       | 28.8±0.2ms      | 1.3   | preprocessing_log.time_highly_variable_genes('pbmc3k', 'off-axis')

Comparison: https://github.com/scverse/scanpy/compare/9d3c340152543a6364d9c55bc11e610027ea319f..e8c5d144270c11de4d9a9afead8dea016dde2d1a
Last changed: Fri, 8 Nov 2024 17:37:15 +0000

More details: https://github.com/scverse/scanpy/pull/3335/checks?check_run_id=32725494502

@flying-sheep flying-sheep marked this pull request as ready for review November 5, 2024 12:57
@flying-sheep flying-sheep modified the milestones: 1.10.4, 1.11.0 Nov 5, 2024
flying-sheep (Member, Author) left a comment:

Almost every njitted function seems to have a clear answer to “this was/wasn’t parallel before, should it be now?”

Comment on lines 156 to 177
@cache
def _is_threading_layer_threadsafe() -> bool:
    import importlib

    import numba

    if (available := LAYERS.get(numba.config.THREADING_LAYER)) is None:
        # given by direct name
        return numba.config.THREADING_LAYER in LAYERS["threadsafe"]

    # given by layer type (safe, …)
    for layer in cast(list[Layer], numba.config.THREADING_LAYER_PRIORITY):
        if layer not in available:
            continue
        try:  # `importlib.util.find_spec` doesn’t work here
            importlib.import_module(f"numba.np.ufunc.{layer}pool")
        except ImportError:
            continue
        # the layer has been found
        return layer in LAYERS["threadsafe"]
    msg = f"No loadable threading layer: {numba.config.THREADING_LAYER=}"
    raise ValueError(msg)
flying-sheep (Member, Author):

This here business is complicated because numba doesn’t support getting the configured threading layer without trying to run something. And that’s exactly what we want to avoid!
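
To illustrate why (a hedged aside, not part of the PR): numba.threading_layer() only reports which layer was selected after a parallel region has actually executed, while numba.config.THREADING_LAYER holds the configured, possibly symbolic, value. Hence the import probing above instead of running a throwaway kernel.

import numba
import numpy as np


@numba.njit(parallel=True)
def _probe(x):
    acc = 0.0
    for i in numba.prange(x.size):
        acc += x[i]
    return acc


_probe(np.ones(10))                  # forces a threading layer to be loaded
print(numba.threading_layer())       # the concrete layer, e.g. "workqueue" or "tbb"
print(numba.config.THREADING_LAYER)  # the configured value, e.g. "default"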

flying-sheep (Member, Author):
/edit: ah the logic isn’t there yet, I’ll fix

     n = len(indptr) - 1
     result = np.ones(n, dtype=np.bool_)
-    for i in range(n):
+    for i in numba.prange(n):
flying-sheep (Member, Author):
I parallelized this function by replacing the numba.njit import with our wrapper. Seems to work fine!
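
As an aside, a minimal sketch of the pattern in this hunk (illustrative only, not the actual scanpy function): per-row work over a CSR indptr, where switching range to numba.prange under parallel=True parallelizes the loop because each iteration writes to its own result slot.

import numba
import numpy as np


@numba.njit(cache=True, parallel=True)
def rows_empty(indptr):
    n = len(indptr) - 1
    result = np.ones(n, dtype=np.bool_)
    for i in numba.prange(n):
        # a CSR row with no stored entries is trivially constant (all zeros)
        result[i] = indptr[i + 1] == indptr[i]
    return result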

@@ -182,7 +181,7 @@ def _gearys_c_vec_W(
     # https://github.com/numba/numba/issues/6774#issuecomment-788789663


-@numba.njit(cache=True)
+@numba.njit(cache=True, parallel=False)  # noqa: TID251
flying-sheep (Member, Author):
This is an inner function. It wasn’t parallelized before and isn’t now. See also @ivirshup’s comment above

@@ -203,7 +202,7 @@ def _gearys_c_inner_sparse_x_densevec(
     return numer / denom


-@numba.njit(cache=True)
+@numba.njit(cache=True, parallel=False)  # noqa: TID251
flying-sheep (Member, Author):
ditto

@@ -137,7 +137,7 @@ def _morans_i_vec(
     return _morans_i_vec_W(g_data, g_indices, g_indptr, x, W)


-@numba.njit(cache=True)
+@numba.njit(cache=True, parallel=False)  # noqa: TID251
flying-sheep (Member, Author):
ditto

@@ -159,7 +159,7 @@ def _morans_i_vec_W(
     return len(x) / W * inum / z2ss


-@numba.njit(cache=True)
+@numba.njit(cache=True, parallel=False)  # noqa: TID251
flying-sheep (Member, Author):
ditto

Comment on lines +203 to +204
+# parallel=False needed for accuracy
+@numba.njit(cache=True, parallel=False)  # noqa: TID251
flying-sheep (Member, Author):
added a comment here explaining why this one is sequential
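
For the record, the accuracy concern is real: floating-point addition is not associative, so a parallel reduction that changes the order of additions can change the result slightly. A tiny illustration outside numba:

import numpy as np

rng = np.random.default_rng(0)
xs = rng.normal(size=1_000_000).astype(np.float32)

acc = np.float32(0.0)
for x in xs:  # strict left-to-right accumulation
    acc += x

# np.sum uses pairwise summation, i.e. a different order of additions;
# the two results typically differ in the last few digits
print(acc, xs.sum())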

Comment on lines +985 to +986
+# TODO: can/should this be parallelized?
+@numba.njit(cache=True)  # noqa: TID251
flying-sheep (Member, Author):
Maybe you have an idea how to answer this question.

flying-sheep (Member, Author) commented Nov 5, 2024

Huh, weird: it gets detected, but calling the non-parallel version doesn’t seem to help.

If I replace the warn with a print, it’s clear that the correct (non-parallel) function is called from Dask’s thread.

Seems like calling numba from a ThreadPoolExecutor isn’t supported at all, even if it comes from dask.

$ hatch test tests/test_utils.py::test_is_constant_dask[csr_matrix-0] --capture=no
Numba function called from a non-threadsafe context. Try installing `tbb`.
Numba function called from a non-threadsafe context. Try installing `tbb`.

Numba workqueue threading layer is terminating: Concurrent access has been detected.

 - The workqueue threading layer is not threadsafe and may not be accessed concurrently by multiple threads. Concurrent access typically occurs through a nested parallel region launch or by calling Numba parallel=True functions from multiple Python threads.
 - Try using the TBB threading layer as an alternative, as it is, itself, threadsafe. Docs: https://numba.readthedocs.io/en/stable/user/threading-layer.html

Fatal Python error: Aborted

Thread 0x000000016fd2f000 (most recent call first):
  File "~/Dev/scanpy/src/scanpy/_compat.py", line 133 in wrapper
  File "~/Dev/scanpy/src/scanpy/_utils/compute/is_constant.py", line 109 in _
  File "<venv>/lib/python3.12/functools.py", line 909 in wrapper
  File "~/Dev/scanpy/src/scanpy/_utils/compute/is_constant.py", line 30 in func
  File "<venv>/lib/python3.12/site-packages/dask/core.py", line 127 in _execute_task
  File "<venv>/lib/python3.12/site-packages/dask/core.py", line 157 in get
  File "<venv>/lib/python3.12/site-packages/dask/optimization.py", line 1001 in __call__
  File "<venv>/lib/python3.12/site-packages/dask/core.py", line 127 in _execute_task
  File "<venv>/lib/python3.12/site-packages/dask/local.py", line 225 in execute_task
  File "<venv>/lib/python3.12/site-packages/dask/local.py", line 239 in batch_execute_tasks
  File "<venv>/lib/python3.12/concurrent/futures/thread.py", line 64 in run
  File "<venv>/lib/python3.12/concurrent/futures/thread.py", line 92 in _worker
  File "<venv>/lib/python3.12/threading.py", line 1010 in run
  File "<venv>/lib/python3.12/threading.py", line 1073 in _bootstrap_inner
  File "<venv>/lib/python3.12/threading.py", line 1030 in _bootstrap

Thread 0x000000016ed23000 (most recent call first):
  File "<venv>/lib/python3.12/concurrent/futures/thread.py", line 89 in _worker
  File "<venv>/lib/python3.12/threading.py", line 1010 in run
  File "<venv>/lib/python3.12/threading.py", line 1073 in _bootstrap_inner
  File "<venv>/lib/python3.12/threading.py", line 1030 in _bootstrap

Current thread 0x000000016dd17000 (most recent call first):
  File "~/Dev/scanpy/src/scanpy/_compat.py", line 133 in wrapper
  File "~/Dev/scanpy/src/scanpy/_utils/compute/is_constant.py", line 109 in _
  File "<venv>/lib/python3.12/functools.py", line 909 in wrapper
  File "~/Dev/scanpy/src/scanpy/_utils/compute/is_constant.py", line 30 in func
  File "<venv>/lib/python3.12/site-packages/dask/core.py", line 127 in _execute_task
  File "<venv>/lib/python3.12/site-packages/dask/core.py", line 157 in get
  File "<venv>/lib/python3.12/site-packages/dask/optimization.py", line 1001 in __call__
  File "<venv>/lib/python3.12/site-packages/dask/core.py", line 127 in _execute_task
  File "<venv>/lib/python3.12/site-packages/dask/local.py", line 225 in execute_task
  File "<venv>/lib/python3.12/site-packages/dask/local.py", line 239 in batch_execute_tasks
  File "<venv>/lib/python3.12/concurrent/futures/thread.py", line 58 in run
  File "<venv>/lib/python3.12/concurrent/futures/thread.py", line 92 in _worker
  File "<venv>/lib/python3.12/threading.py", line 1010 in run
  File "<venv>/lib/python3.12/threading.py", line 1073 in _bootstrap_inner
  File "<venv>/lib/python3.12/threading.py", line 1030 in _bootstrap

Thread 0x000000016cd0b000 (most recent call first):
  File "<venv>/lib/python3.12/socket.py", line 295 in accept
  File "<venv>/lib/python3.12/site-packages/pytest_rerunfailures.py", line 433 in run_server
  File "<venv>/lib/python3.12/threading.py", line 1010 in run
  File "<venv>/lib/python3.12/threading.py", line 1073 in _bootstrap_inner
  File "<venv>/lib/python3.12/threading.py", line 1030 in _bootstrap

Thread 0x00000001f9bdf240 (most recent call first):
  File "<venv>/lib/python3.12/threading.py", line 355 in wait
  File "<venv>/lib/python3.12/queue.py", line 171 in get
  File "<venv>/lib/python3.12/site-packages/dask/local.py", line 138 in queue_get
  File "<venv>/lib/python3.12/site-packages/dask/local.py", line 501 in get_async
  File "<venv>/lib/python3.12/site-packages/dask/threaded.py", line 90 in get
  File "<venv>/lib/python3.12/site-packages/dask/base.py", line 662 in compute
  File "<venv>/lib/python3.12/site-packages/dask/base.py", line 376 in compute
  File "~/Dev/scanpy/tests/test_utils.py", line 243 in test_is_constant_dask
  File "<venv>/lib/python3.12/site-packages/_pytest/python.py", line 159 in pytest_pyfunc_call
  File "<venv>/lib/python3.12/site-packages/pluggy/_callers.py", line 103 in _multicall
  File "<venv>/lib/python3.12/site-packages/pluggy/_manager.py", line 120 in _hookexec
  File "<venv>/lib/python3.12/site-packages/pluggy/_hooks.py", line 513 in __call__
  File "<venv>/lib/python3.12/site-packages/_pytest/python.py", line 1627 in runtest
  File "<venv>/lib/python3.12/site-packages/_pytest/runner.py", line 174 in pytest_runtest_call
  File "<venv>/lib/python3.12/site-packages/pluggy/_callers.py", line 103 in _multicall
  File "<venv>/lib/python3.12/site-packages/pluggy/_manager.py", line 120 in _hookexec
  File "<venv>/lib/python3.12/site-packages/pluggy/_hooks.py", line 513 in __call__
  File "<venv>/lib/python3.12/site-packages/_pytest/runner.py", line 242 in <lambda>
  File "<venv>/lib/python3.12/site-packages/_pytest/runner.py", line 341 in from_call
  File "<venv>/lib/python3.12/site-packages/_pytest/runner.py", line 241 in call_and_report
  File "<venv>/lib/python3.12/site-packages/_pytest/runner.py", line 132 in runtestprotocol
  File "<venv>/lib/python3.12/site-packages/_pytest/runner.py", line 113 in pytest_runtest_protocol
  File "<venv>/lib/python3.12/site-packages/pluggy/_callers.py", line 103 in _multicall
  File "<venv>/lib/python3.12/site-packages/pluggy/_manager.py", line 120 in _hookexec
  File "<venv>/lib/python3.12/site-packages/pluggy/_hooks.py", line 513 in __call__
  File "<venv>/lib/python3.12/site-packages/_pytest/main.py", line 362 in pytest_runtestloop
  File "<venv>/lib/python3.12/site-packages/pluggy/_callers.py", line 103 in _multicall
  File "<venv>/lib/python3.12/site-packages/pluggy/_manager.py", line 120 in _hookexec
  File "<venv>/lib/python3.12/site-packages/pluggy/_hooks.py", line 513 in __call__
  File "<venv>/lib/python3.12/site-packages/_pytest/main.py", line 337 in _main
  File "<venv>/lib/python3.12/site-packages/_pytest/main.py", line 283 in wrap_session
  File "<venv>/lib/python3.12/site-packages/_pytest/main.py", line 330 in pytest_cmdline_main
  File "<venv>/lib/python3.12/site-packages/pluggy/_callers.py", line 103 in _multicall
  File "<venv>/lib/python3.12/site-packages/pluggy/_manager.py", line 120 in _hookexec
  File "<venv>/lib/python3.12/site-packages/pluggy/_hooks.py", line 513 in __call__
  File "<venv>/lib/python3.12/site-packages/_pytest/config/__init__.py", line 175 in main
  File "<venv>/lib/python3.12/site-packages/_pytest/config/__init__.py", line 201 in console_main
  File "<venv>/bin/pytest", line 10 in <module>

ilan-gold (Contributor) left a comment:

Apologies for the jitting comment if it’s wrong; otherwise looks good.

@flying-sheep flying-sheep enabled auto-merge (squash) November 8, 2024 16:55
@flying-sheep flying-sheep merged commit d0adc25 into main Nov 8, 2024
14 of 15 checks passed
@flying-sheep flying-sheep deleted the safe-njit branch November 8, 2024 17:17
kaushalprasadhial pushed a commit to sanchit-misra/scanpy that referenced this pull request Feb 4, 2025