Skip to content

Commit f2aad8b

Browse files
committed
Fix test_nan_propagation for immutable arrays.
1 parent 33f2d2e commit f2aad8b

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

array_api_tests/test_special_cases.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,9 @@
2323
from warnings import warn
2424

2525
import pytest
26-
from hypothesis import given, note, settings
26+
from hypothesis import given, note, settings, assume
2727
from hypothesis import strategies as st
28+
from hypothesis.strategies import composite
2829

2930
from array_api_tests.typing import Array, DataType
3031

@@ -1321,6 +1322,11 @@ def test_empty_arrays(func_name, expected): # TODO: parse docstrings to get exp
13211322
else:
13221323
assert out == expected, msg
13231324

1325+
@composite
1326+
def not_all_false(draw, shape):
1327+
ret = draw(hh.arrays(dtype=hh.bool_dtype, shape=shape))
1328+
assume(xp.any(ret))
1329+
return ret
13241330

13251331
@pytest.mark.parametrize(
13261332
"func_name", [f.__name__ for f in category_to_funcs["statistical"]]
@@ -1331,10 +1337,8 @@ def test_empty_arrays(func_name, expected): # TODO: parse docstrings to get exp
13311337
)
13321338
def test_nan_propagation(func_name, x, data):
13331339
func = getattr(xp, func_name)
1334-
set_idx = data.draw(
1335-
xps.indices(x.shape, max_dims=0, allow_ellipsis=False), label="set idx"
1336-
)
1337-
x[set_idx] = float("nan")
1340+
nan_positions = data.draw(not_all_false(x.shape))
1341+
x = xp.where(nan_positions, xp.asarray(float("nan")), x)
13381342
note(f"{x=}")
13391343

13401344
out = func(x)

0 commit comments

Comments
 (0)