Description
Functions like `equal`, `greater`, and so on (and their operator equivalents) don't allow comparing dtypes that cannot be promoted together. This is particularly annoying because it makes it impossible to compare `uint64` with `int64` at all, since the two cannot be promoted.
```py
>>> import array_api_strict as xp
>>> xp.asarray(0, dtype=xp.int64) < xp.asarray(1, dtype=xp.uint64)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/aaronmeurer/Documents/array-api-strict/array_api_strict/_array_object.py", line 717, in __lt__
    other = self._check_allowed_dtypes(other, "real numeric", "__lt__")
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/aaronmeurer/Documents/array-api-strict/array_api_strict/_array_object.py", line 179, in _check_allowed_dtypes
    res_dtype = _result_type(self.dtype, other.dtype)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/aaronmeurer/Documents/array-api-strict/array_api_strict/_dtypes.py", line 217, in _result_type
    raise TypeError(f"{type1} and {type2} cannot be type promoted together")
TypeError: array_api_strict.int64 and array_api_strict.uint64 cannot be type promoted together
```
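For illustration, here is a minimal sketch of how an implementation could compare `int64` against `uint64` exactly without first promoting to a common dtype. It assumes NumPy as the backend and uses a hypothetical helper, `less_int64_uint64`, which is not part of array-api-strict or NumPy. The idea is that a negative `int64` is smaller than every `uint64`, and any non-negative `int64` can be cast losslessly to `uint64`.

```python
import numpy as np


def less_int64_uint64(x, y):
    """Hypothetical helper: elementwise exact x < y for int64 x and uint64 y."""
    x = np.asarray(x, dtype=np.int64)
    y = np.asarray(y, dtype=np.uint64)
    # A negative int64 is always less than any uint64 (which is >= 0).
    neg = x < 0
    # Non-negative int64 values fit in uint64, so this cast is lossless;
    # negative entries are replaced by 0 since their result is already known.
    x_u = np.where(neg, 0, x).astype(np.uint64)
    return neg | (x_u < y)


x = np.array([-1, 0, 2**63 - 1], dtype=np.int64)
y = np.array([0, 0, 2**63], dtype=np.uint64)
print(less_int64_uint64(x, y))                        # [ True False  True]

# Comparing through float64 instead rounds 2**63 - 1 up to 2**63,
# so the last pair incorrectly compares as not-less.
print(x.astype(np.float64) < y.astype(np.float64))    # [ True False False]
```

NumPy itself promotes `int64` with `uint64` to `float64`, which is why a naive promote-then-compare approach gives the wrong answer for the last pair above.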
However, the standard doesn't actually say anywhere in `greater` or `__gt__` that the input types must be promotable:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.greater.html#greater
https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__gt__.html
It only says that they should have a real numeric data type. So in principle, these operators should even work when comparing floats with integers.
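As a small illustration of what "working" could mean for mixed float/integer comparisons, plain Python scalars (not array_api_strict arrays) already compare `int` and `float` exactly, without converting the integer to a float first. This is only an analogy; whether the standard intends exact semantics like this is exactly the open question.

```python
big = 2**53 + 1          # exactly representable as a Python int, but not as a float64
f = float(2**53)         # 9007199254740992.0

# Python compares int and float by exact value.
print(big == f)          # False
print(big > f)           # True

# Converting the int to float first (a "promote to float" strategy) rounds
# 2**53 + 1 to 2**53, so the two values wrongly compare as equal.
print(float(big) == f)   # True
```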
And `equal` allows any data type:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.equal.html#equal
https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__eq__.html
It might be good to get some clarification in the standard about this, for instance on how `==` should behave when mixing certain dtype combinations.