Skip to content

Commit a25dfa8

Browse files
committed
Replace array filtering with xps.from_dtype() kwargs
1 parent 4061965 commit a25dfa8

File tree

1 file changed

+17
-19
lines changed

1 file changed

+17
-19
lines changed

array_api_tests/test_type_promotion.py

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from hypothesis import strategies as st
1111

1212
from . import _array_module as xp
13-
from . import array_helpers as ah
1413
from . import dtype_helpers as dh
1514
from . import hypothesis_helpers as hh
1615
from . import xps
@@ -127,11 +126,10 @@ def test_stack(shape, dtypes, kw, data):
127126
]
128127

129128

130-
# We apply filters to xps.arrays() so we don't generate array elements that
131-
# are erroneous or undefined for a function/operator.
132-
filters = defaultdict(
133-
lambda: lambda _: True,
134-
{func: lambda x: ah.all(x > 0) for func in bitwise_shift_funcs},
129+
# We pass kwargs to the elements strategy used by xps.arrays() so that we don't
130+
# generate array elements that are erroneous or undefined for a function.
131+
func_elements = defaultdict(
132+
lambda: None, {func: {'min_value': 1} for func in bitwise_shift_funcs}
135133
)
136134

137135

@@ -178,10 +176,10 @@ def make_id(
178176
@given(data=st.data())
179177
def test_func_promotion(func_name, in_dtypes, out_dtype, data):
180178
func = getattr(xp, func_name)
181-
x_filter = filters[func_name]
179+
elements = func_elements[func_name]
182180
if len(in_dtypes) == 1:
183181
x = data.draw(
184-
xps.arrays(dtype=in_dtypes[0], shape=hh.shapes()).filter(x_filter),
182+
xps.arrays(dtype=in_dtypes[0], shape=hh.shapes(), elements=elements),
185183
label='x',
186184
)
187185
out = func(x)
@@ -192,7 +190,7 @@ def test_func_promotion(func_name, in_dtypes, out_dtype, data):
192190
)
193191
for i, (dtype, shape) in enumerate(zip(in_dtypes, shapes), 1):
194192
x = data.draw(
195-
xps.arrays(dtype=dtype, shape=shape).filter(x_filter), label=f'x{i}'
193+
xps.arrays(dtype=dtype, shape=shape, elements=elements), label=f'x{i}'
196194
)
197195
arrays.append(x)
198196
try:
@@ -301,10 +299,10 @@ def test_vecdot(in_dtypes, out_dtype, shapes, data):
301299
@pytest.mark.parametrize('op, expr, in_dtypes, out_dtype', op_params)
302300
@given(data=st.data())
303301
def test_op_promotion(op, expr, in_dtypes, out_dtype, data):
304-
x_filter = filters[op]
302+
elements = func_elements[func_name]
305303
if len(in_dtypes) == 1:
306304
x = data.draw(
307-
xps.arrays(dtype=in_dtypes[0], shape=hh.shapes()).filter(x_filter),
305+
xps.arrays(dtype=in_dtypes[0], shape=hh.shapes(), elements=elements),
308306
label='x',
309307
)
310308
out = eval(expr, {'x': x})
@@ -315,7 +313,7 @@ def test_op_promotion(op, expr, in_dtypes, out_dtype, data):
315313
)
316314
for i, (dtype, shape) in enumerate(zip(in_dtypes, shapes), 1):
317315
locals_[f'x{i}'] = data.draw(
318-
xps.arrays(dtype=dtype, shape=shape).filter(x_filter), label=f'x{i}'
316+
xps.arrays(dtype=dtype, shape=shape, elements=elements), label=f'x{i}'
319317
)
320318
try:
321319
out = eval(expr, locals_)
@@ -349,12 +347,12 @@ def test_op_promotion(op, expr, in_dtypes, out_dtype, data):
349347
@given(shapes=hh.mutually_broadcastable_shapes(2), data=st.data())
350348
def test_inplace_op_promotion(op, expr, in_dtypes, out_dtype, shapes, data):
351349
assume(len(shapes[0]) >= len(shapes[1]))
352-
x_filter = filters[op]
350+
elements = func_elements[func_name]
353351
x1 = data.draw(
354-
xps.arrays(dtype=in_dtypes[0], shape=shapes[0]).filter(x_filter), label='x1'
352+
xps.arrays(dtype=in_dtypes[0], shape=shapes[0], elements=elements), label='x1'
355353
)
356354
x2 = data.draw(
357-
xps.arrays(dtype=in_dtypes[1], shape=shapes[1]).filter(x_filter), label='x2'
355+
xps.arrays(dtype=in_dtypes[1], shape=shapes[1], elements=elements), label='x2'
358356
)
359357
locals_ = {'x1': x1, 'x2': x2}
360358
try:
@@ -386,11 +384,11 @@ def test_inplace_op_promotion(op, expr, in_dtypes, out_dtype, shapes, data):
386384
@pytest.mark.parametrize('op, expr, in_dtype, in_stype, out_dtype', op_scalar_params)
387385
@given(data=st.data())
388386
def test_op_scalar_promotion(op, expr, in_dtype, in_stype, out_dtype, data):
389-
x_filter = filters[op]
387+
elements = func_elements[func_name]
390388
kw = {k: in_stype is float for k in ('allow_nan', 'allow_infinity')}
391389
s = data.draw(xps.from_dtype(in_dtype, **kw).map(in_stype), label='scalar')
392390
x = data.draw(
393-
xps.arrays(dtype=in_dtype, shape=hh.shapes()).filter(x_filter), label='x'
391+
xps.arrays(dtype=in_dtype, shape=hh.shapes(), elements=elements), label='x'
394392
)
395393
try:
396394
out = eval(expr, {'x': x, 's': s})
@@ -420,11 +418,11 @@ def test_op_scalar_promotion(op, expr, in_dtype, in_stype, out_dtype, data):
420418
@pytest.mark.parametrize('op, expr, dtype, in_stype', inplace_scalar_params)
421419
@given(data=st.data())
422420
def test_inplace_op_scalar_promotion(op, expr, dtype, in_stype, data):
423-
x_filter = filters[op]
421+
elements = func_elements[func_name]
424422
kw = {k: in_stype is float for k in ('allow_nan', 'allow_infinity')}
425423
s = data.draw(xps.from_dtype(dtype, **kw).map(in_stype), label='scalar')
426424
x = data.draw(
427-
xps.arrays(dtype=dtype, shape=hh.shapes()).filter(x_filter), label='x'
425+
xps.arrays(dtype=dtype, shape=hh.shapes(), elements=elements), label='x'
428426
)
429427
locals_ = {'x': x, 's': s}
430428
try:

0 commit comments

Comments
 (0)