23
23
from warnings import warn
24
24
25
25
import pytest
26
- from hypothesis import given , note , settings
26
+ from hypothesis import given , note , settings , assume
27
27
from hypothesis import strategies as st
28
+ from hypothesis .strategies import composite
28
29
29
30
from array_api_tests .typing import Array , DataType
30
31
@@ -1321,6 +1322,11 @@ def test_empty_arrays(func_name, expected): # TODO: parse docstrings to get exp
1321
1322
else :
1322
1323
assert out == expected , msg
1323
1324
1325
+ @composite
1326
+ def not_all_false (draw , shape ):
1327
+ ret = draw (hh .arrays (dtype = hh .bool_dtype , shape = shape ))
1328
+ assume (xp .any (ret ))
1329
+ return ret
1324
1330
1325
1331
@pytest .mark .parametrize (
1326
1332
"func_name" , [f .__name__ for f in category_to_funcs ["statistical" ]]
@@ -1331,10 +1337,8 @@ def test_empty_arrays(func_name, expected): # TODO: parse docstrings to get exp
1331
1337
)
1332
1338
def test_nan_propagation (func_name , x , data ):
1333
1339
func = getattr (xp , func_name )
1334
- set_idx = data .draw (
1335
- xps .indices (x .shape , max_dims = 0 , allow_ellipsis = False ), label = "set idx"
1336
- )
1337
- x [set_idx ] = float ("nan" )
1340
+ nan_positions = data .draw (not_all_false (x .shape ))
1341
+ x = xp .where (nan_positions , xp .asarray (float ("nan" )), x )
1338
1342
note (f"{ x = } " )
1339
1343
1340
1344
out = func (x )
0 commit comments