Support 2024.12 #223

Open
3 of 9 tasks
ev-br opened this issue Dec 27, 2024 · 5 comments

@ev-br
Member

ev-br commented Dec 27, 2024

This is a tracking issue for the 2024.12 revision support in array-api-compat.

The 2024.12 revision is still in draft, so the list below is preliminary:

@crusaderky
Contributor

take_along_axis in dask is a very complicated matter when there are multiple chunks: dask/dask#3663

@mdhaber

mdhaber commented Feb 1, 2025

Can we start implementing the draft standard in array API compat, especially result_type?

Before scalars were allowed in array_namespace and result_type, I don't think it would have been unusual to do something like:

xp = array_namespace(x, y)
x = xp.asarray(x)
y = xp.asarray(y)
dtype = result_type(x, y)

especially since the following would not be acceptable according to the standard if one of x or y can be a Python scalar:

xp = array_namespace(x, y)
dtype = result_type(x, y)  # fails for array_api_strict if `x` or `y` is a Python scalar

But if the intent of result_type is to find the correct result dtype of an operation between x and y, the former can produce an incorrect result when scalars are involved, e.g.:

from array_api_compat import array_namespace, numpy as np
x = np.asarray([1, 2, 3], dtype=np.float32)
y = 3.

xp = array_namespace(x, y)
x_ = xp.asarray(x)  # float32
y_ = xp.asarray(y)  # float64

xp.result_type(x_, y_)  # float64
(x*y).dtype  # float32

So currently, we are between a rock and a hard place: the code we would naturally write today (and that was fine before scalars were accepted) is no longer correct once scalars can be involved.

Of course there are workarounds, but rather than writing workarounds and simplifying them once the standard is published and array_api_compat is updated, it would save quite a bit of work (and/or avoid releases with suboptimal dtype behavior) if array_api_compat were to support the code we will want at the end of 2025.
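
For illustration, one such workaround could look like the sketch below. It is not part of array-api-compat or the standard: the helper name `result_dtype_with_scalars` is hypothetical, plain NumPy stands in for the namespace, and the behavior for mixed kinds (e.g. an integer array with a Python float) is whatever the underlying library does. The idea is to call result_type on the arrays and dtypes only, then let the library apply its own scalar promotion on an empty probe array.

```python
import numpy as np

# Exact Python scalar types; NumPy scalars such as np.float64 are
# deliberately excluded and treated like arrays.
_PY_SCALARS = (bool, int, float, complex)


def result_dtype_with_scalars(*args, xp=np):
    """Hypothetical helper: result dtype of combining `args`, where
    Python scalars do not force an upcast on their own."""
    scalars = [a for a in args if type(a) in _PY_SCALARS]
    arrays = [a for a in args if type(a) not in _PY_SCALARS]
    if not arrays:
        raise ValueError("at least one array or dtype is required")

    # Promote arrays and dtypes among themselves first.
    dtype = xp.result_type(*arrays)

    # Then let the library promote each Python scalar against an
    # empty array of that dtype, as a binary operation would.
    probe = xp.empty(0, dtype=dtype)
    for s in scalars:
        probe = probe + s
    return probe.dtype


x = np.asarray([1, 2, 3], dtype=np.float32)
y = 3.0
print(result_dtype_with_scalars(x, y))  # same dtype as (x * y).dtype
```

With this, the dtype agrees with what `x * y` actually produces, at the cost of an extra empty allocation per call.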

@ev-br
Member Author

ev-br commented Feb 1, 2025

Yes, the work is ongoing, help appreciated :-). The aim is indeed to have -compat and -strict releases to support 2024.12 soon after the spec is finalized.

The current state is:

In [12]: import array_api_strict as xp

In [13]: xp.result_type(xp.ones(3, dtype=xp.float32), 3)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[13], line 1
----> 1 xp.result_type(xp.ones(3, dtype=xp.float32), 3)

File ~/repos/array-api-strict/array_api_strict/_data_type_functions.py:238, in result_type(*arrays_and_dtypes)
    235     return result
    237 if get_array_api_strict_flags()['api_version'] <= '2023.12':
--> 238     raise TypeError("result_type() inputs must be array_api arrays or dtypes")
    240 # promote python scalars given the result_type for all arrays/dtypes
    241 from ._creation_functions import empty

TypeError: result_type() inputs must be array_api arrays or dtypes

In [14]: xp.set_array_api_strict_flags(api_version='2024.12')
/home/br/repos/array-api-strict/array_api_strict/_flags.py:144: UserWarning: The 2024.12 version of the array API specification is in draft status. Not all features are implemented in array_api_strict, some functions may not be fully tested, and behaviors are subject to change before the final standard release.
  warnings.warn(f"The {draft_version} version of the array API specification is in draft status. Not all features are implemented in array_api_strict, some functions may not be fully tested, and behaviors are subject to change before the final standard release.")

In [15]: xp.result_type(xp.ones(3, dtype=xp.float32), 3)
Out[15]: array_api_strict.float32

In [16]: xp.result_type(xp.ones(3, dtype=xp.uint8), 3)
Out[16]: array_api_strict.uint8

  • array-api-compat: draft implementation in #224 (Add draft support for 2024.12 revision)

  • array-api-tests: not much of 2024.12 is tested, so help here would be most welcome. Especially so if you like working with hypothesis :-).

@mdhaber

mdhaber commented Feb 1, 2025

Great. I was tempted to update result_type but wasn't sure if it was fair game yet. I'll post here if I start to work on it (but wouldn't want that to discourage anyone else in the meantime).

@ev-br
Member Author

ev-br commented Feb 1, 2025

I'd say it's totally fair game. We won't merge PRs to -strict and -compat until a matching spec PR lands, that's pretty much the only boundary condition.
