10
10
from hypothesis import strategies as st
11
11
12
12
from . import _array_module as xp
13
- from . import array_helpers as ah
14
13
from . import dtype_helpers as dh
15
14
from . import hypothesis_helpers as hh
16
15
from . import xps
@@ -127,11 +126,10 @@ def test_stack(shape, dtypes, kw, data):
127
126
]
128
127
129
128
130
- # We apply filters to xps.arrays() so we don't generate array elements that
131
- # are erroneous or undefined for a function/operator.
132
- filters = defaultdict (
133
- lambda : lambda _ : True ,
134
- {func : lambda x : ah .all (x > 0 ) for func in bitwise_shift_funcs },
129
+ # We pass kwargs to the elements strategy used by xps.arrays() so that we don't
130
+ # generate array elements that are erroneous or undefined for a function.
131
+ func_elements = defaultdict (
132
+ lambda : None , {func : {'min_value' : 1 } for func in bitwise_shift_funcs }
135
133
)
136
134
137
135
@@ -178,10 +176,10 @@ def make_id(
178
176
@given (data = st .data ())
179
177
def test_func_promotion (func_name , in_dtypes , out_dtype , data ):
180
178
func = getattr (xp , func_name )
181
- x_filter = filters [func_name ]
179
+ elements = func_elements [func_name ]
182
180
if len (in_dtypes ) == 1 :
183
181
x = data .draw (
184
- xps .arrays (dtype = in_dtypes [0 ], shape = hh .shapes ()). filter ( x_filter ),
182
+ xps .arrays (dtype = in_dtypes [0 ], shape = hh .shapes (), elements = elements ),
185
183
label = 'x' ,
186
184
)
187
185
out = func (x )
@@ -192,7 +190,7 @@ def test_func_promotion(func_name, in_dtypes, out_dtype, data):
192
190
)
193
191
for i , (dtype , shape ) in enumerate (zip (in_dtypes , shapes ), 1 ):
194
192
x = data .draw (
195
- xps .arrays (dtype = dtype , shape = shape ). filter ( x_filter ), label = f'x{ i } '
193
+ xps .arrays (dtype = dtype , shape = shape , elements = elements ), label = f'x{ i } '
196
194
)
197
195
arrays .append (x )
198
196
try :
@@ -301,10 +299,10 @@ def test_vecdot(in_dtypes, out_dtype, shapes, data):
301
299
@pytest .mark .parametrize ('op, expr, in_dtypes, out_dtype' , op_params )
302
300
@given (data = st .data ())
303
301
def test_op_promotion (op , expr , in_dtypes , out_dtype , data ):
304
- x_filter = filters [ op ]
302
+ elements = func_elements [ func_name ]
305
303
if len (in_dtypes ) == 1 :
306
304
x = data .draw (
307
- xps .arrays (dtype = in_dtypes [0 ], shape = hh .shapes ()). filter ( x_filter ),
305
+ xps .arrays (dtype = in_dtypes [0 ], shape = hh .shapes (), elements = elements ),
308
306
label = 'x' ,
309
307
)
310
308
out = eval (expr , {'x' : x })
@@ -315,7 +313,7 @@ def test_op_promotion(op, expr, in_dtypes, out_dtype, data):
315
313
)
316
314
for i , (dtype , shape ) in enumerate (zip (in_dtypes , shapes ), 1 ):
317
315
locals_ [f'x{ i } ' ] = data .draw (
318
- xps .arrays (dtype = dtype , shape = shape ). filter ( x_filter ), label = f'x{ i } '
316
+ xps .arrays (dtype = dtype , shape = shape , elements = elements ), label = f'x{ i } '
319
317
)
320
318
try :
321
319
out = eval (expr , locals_ )
@@ -349,12 +347,12 @@ def test_op_promotion(op, expr, in_dtypes, out_dtype, data):
349
347
@given (shapes = hh .mutually_broadcastable_shapes (2 ), data = st .data ())
350
348
def test_inplace_op_promotion (op , expr , in_dtypes , out_dtype , shapes , data ):
351
349
assume (len (shapes [0 ]) >= len (shapes [1 ]))
352
- x_filter = filters [ op ]
350
+ elements = func_elements [ func_name ]
353
351
x1 = data .draw (
354
- xps .arrays (dtype = in_dtypes [0 ], shape = shapes [0 ]). filter ( x_filter ), label = 'x1'
352
+ xps .arrays (dtype = in_dtypes [0 ], shape = shapes [0 ], elements = elements ), label = 'x1'
355
353
)
356
354
x2 = data .draw (
357
- xps .arrays (dtype = in_dtypes [1 ], shape = shapes [1 ]). filter ( x_filter ), label = 'x2'
355
+ xps .arrays (dtype = in_dtypes [1 ], shape = shapes [1 ], elements = elements ), label = 'x2'
358
356
)
359
357
locals_ = {'x1' : x1 , 'x2' : x2 }
360
358
try :
@@ -386,11 +384,11 @@ def test_inplace_op_promotion(op, expr, in_dtypes, out_dtype, shapes, data):
386
384
@pytest .mark .parametrize ('op, expr, in_dtype, in_stype, out_dtype' , op_scalar_params )
387
385
@given (data = st .data ())
388
386
def test_op_scalar_promotion (op , expr , in_dtype , in_stype , out_dtype , data ):
389
- x_filter = filters [ op ]
387
+ elements = func_elements [ func_name ]
390
388
kw = {k : in_stype is float for k in ('allow_nan' , 'allow_infinity' )}
391
389
s = data .draw (xps .from_dtype (in_dtype , ** kw ).map (in_stype ), label = 'scalar' )
392
390
x = data .draw (
393
- xps .arrays (dtype = in_dtype , shape = hh .shapes ()). filter ( x_filter ), label = 'x'
391
+ xps .arrays (dtype = in_dtype , shape = hh .shapes (), elements = elements ), label = 'x'
394
392
)
395
393
try :
396
394
out = eval (expr , {'x' : x , 's' : s })
@@ -420,11 +418,11 @@ def test_op_scalar_promotion(op, expr, in_dtype, in_stype, out_dtype, data):
420
418
@pytest .mark .parametrize ('op, expr, dtype, in_stype' , inplace_scalar_params )
421
419
@given (data = st .data ())
422
420
def test_inplace_op_scalar_promotion (op , expr , dtype , in_stype , data ):
423
- x_filter = filters [ op ]
421
+ elements = func_elements [ func_name ]
424
422
kw = {k : in_stype is float for k in ('allow_nan' , 'allow_infinity' )}
425
423
s = data .draw (xps .from_dtype (dtype , ** kw ).map (in_stype ), label = 'scalar' )
426
424
x = data .draw (
427
- xps .arrays (dtype = dtype , shape = hh .shapes ()). filter ( x_filter ), label = 'x'
425
+ xps .arrays (dtype = dtype , shape = hh .shapes (), elements = elements ), label = 'x'
428
426
)
429
427
locals_ = {'x' : x , 's' : s }
430
428
try :
0 commit comments