diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 966b6e1..4d808e4 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -645,7 +645,7 @@ def device(x: Array, /) -> Device: to_device : Move array data to a different device. """ - if is_numpy_array(x) or is_ndonnx_array(x): + if is_numpy_array(x): return "cpu" elif is_dask_array(x): # Peek at the metadata of the Dask array to determine type diff --git a/tests/_helpers.py b/tests/_helpers.py index 82623f8..2c8f314 100644 --- a/tests/_helpers.py +++ b/tests/_helpers.py @@ -21,3 +21,14 @@ def import_(library, wrapper=False): library = 'array_api_compat.' + library return import_module(library) + + +def xfail(request: pytest.FixtureRequest, reason: str) -> None: + """ + XFAIL the currently running test. + + Unlike ``pytest.xfail``, allow rest of test to execute instead of immediately + halting it, so that it may result in a XPASS. + xref https://github.com/pandas-dev/pandas/issues/38902 + """ + request.node.add_marker(pytest.mark.xfail(reason=reason)) diff --git a/tests/test_common.py b/tests/test_common.py index 5bc8e98..32876e6 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -17,7 +17,8 @@ from array_api_compat import ( device, is_array_api_obj, is_lazy_array, is_writeable_array, size, to_device ) -from ._helpers import import_, wrapped_libraries, all_libraries +from ._helpers import all_libraries, import_, wrapped_libraries, xfail + is_array_functions = { 'numpy': 'is_numpy_array', @@ -188,7 +189,10 @@ class C: @pytest.mark.parametrize("library", all_libraries) -def test_device(library): +def test_device(library, request): + if library == "ndonnx": + xfail(request, reason="Needs ndonnx >=0.9.4") + xp = import_(library, wrapper=True) # We can't test much for device() and to_device() other than that @@ -226,24 +230,19 @@ def test_to_device_host(library): @pytest.mark.parametrize("target_library", is_array_functions.keys()) @pytest.mark.parametrize("source_library", is_array_functions.keys()) def test_asarray_cross_library(source_library, target_library, request): - def _xfail(reason: str) -> None: - # Allow rest of test to execute instead of immediately xfailing - # xref https://github.com/pandas-dev/pandas/issues/38902 - request.node.add_marker(pytest.mark.xfail(reason=reason)) - if source_library == "dask.array" and target_library == "torch": # TODO: remove xfail once # https://github.com/dask/dask/issues/8260 is resolved - _xfail(reason="Bug in dask raising error on conversion") + xfail(request, reason="Bug in dask raising error on conversion") elif ( source_library == "ndonnx" and target_library not in ("array_api_strict", "ndonnx", "numpy") ): - _xfail(reason="The truth value of lazy Array Array(dtype=Boolean) is unknown") + xfail(request, reason="The truth value of lazy Array Array(dtype=Boolean) is unknown") elif source_library == "ndonnx" and target_library == "numpy": - _xfail(reason="produces numpy array of ndonnx scalar arrays") + xfail(request, reason="produces numpy array of ndonnx scalar arrays") elif source_library == "jax.numpy" and target_library == "torch": - _xfail(reason="casts int to float") + xfail(request, reason="casts int to float") elif source_library == "cupy" and target_library != "cupy": # cupy explicitly disallows implicit conversions to CPU pytest.skip(reason="cupy does not support implicit conversion to CPU")