Skip to content

Commit cf31202

Browse files
authored
BUG: reject non-scalar fill_value in full{_like} (#181)
closes gh-55 reviewed at #181
1 parent 95508e8 commit cf31202

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

array_api_strict/_creation_functions.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -240,8 +240,9 @@ def full(
240240
_check_valid_dtype(dtype)
241241
_check_device(device)
242242

243-
if isinstance(fill_value, Array) and fill_value.ndim == 0:
244-
fill_value = fill_value._array
243+
if not isinstance(fill_value, bool | int | float | complex):
244+
msg = f"Expected Python scalar fill_value, got type {type(fill_value)}"
245+
raise TypeError(msg)
245246
res = np.full(shape, fill_value, dtype=_np_dtype(dtype))
246247
if DType(res.dtype) not in _all_dtypes:
247248
# This will happen if the fill value is not something that NumPy
@@ -270,6 +271,10 @@ def full_like(
270271
if device is None:
271272
device = x.device
272273

274+
if not isinstance(fill_value, bool | int | float | complex):
275+
msg = f"Expected Python scalar fill_value, got type {type(fill_value)}"
276+
raise TypeError(msg)
277+
273278
res = np.full_like(x._array, fill_value, dtype=_np_dtype(dtype))
274279
if DType(res.dtype) not in _all_dtypes:
275280
# This will happen if the fill value is not something that NumPy

array_api_strict/tests/test_creation_functions.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@ def test_full_errors():
161161
assert_raises(ValueError, lambda: full((1,), 0, device="gpu"))
162162
assert_raises(ValueError, lambda: full((1,), 0, dtype=int))
163163
assert_raises(ValueError, lambda: full((1,), 0, dtype="i"))
164+
assert_raises(TypeError, lambda: full((1,), asarray(0)))
164165

165166

166167
def test_full_like_errors():
@@ -169,6 +170,7 @@ def test_full_like_errors():
169170
assert_raises(ValueError, lambda: full_like(asarray(1), 0, device="gpu"))
170171
assert_raises(ValueError, lambda: full_like(asarray(1), 0, dtype=int))
171172
assert_raises(ValueError, lambda: full_like(asarray(1), 0, dtype="i"))
173+
assert_raises(TypeError, lambda: full(asarray(1), asarray(0)))
172174

173175

174176
def test_linspace_errors():

0 commit comments

Comments
 (0)