25
25
from . import pytest_helpers as ph
26
26
from . import shape_helpers as sh
27
27
from . import xps
28
- from .typing import Array , DataType , Param , Scalar , Shape
28
+ from .typing import Array , DataType , Param , Scalar , ScalarType , Shape
29
29
30
30
pytestmark = pytest .mark .ci
31
31
@@ -38,12 +38,68 @@ def boolean_and_all_integer_dtypes() -> st.SearchStrategy[DataType]:
38
38
return xps .boolean_dtypes () | all_integer_dtypes ()
39
39
40
40
41
- def isclose (n1 : Union [int , float ], n2 : Union [int , float ]):
41
+ def isclose (n1 : Union [int , float ], n2 : Union [int , float ]) -> bool :
42
42
if not (math .isfinite (n1 ) and math .isfinite (n2 )):
43
43
raise ValueError (f"{ n1 = } and { n1 = } , but input must be finite" )
44
44
return math .isclose (n1 , n2 , rel_tol = 0.25 , abs_tol = 1 )
45
45
46
46
47
+ def unary_assert_against_refimpl (
48
+ func_name : str ,
49
+ in_stype : ScalarType ,
50
+ in_ : Array ,
51
+ res : Array ,
52
+ refimpl : Callable [[Scalar ], Scalar ],
53
+ expr_template : str ,
54
+ res_stype : Optional [ScalarType ] = None ,
55
+ ):
56
+ if in_ .shape != res .shape :
57
+ raise ValueError (f"{ res .shape = } , but should be { in_ .shape = } " )
58
+ if res_stype is None :
59
+ res_stype = in_stype
60
+ for idx in sh .ndindex (in_ .shape ):
61
+ scalar_i = in_stype (in_ [idx ])
62
+ expected = refimpl (scalar_i )
63
+ scalar_o = res_stype (res [idx ])
64
+ f_i = sh .fmt_idx ("x" , idx )
65
+ f_o = sh .fmt_idx ("out" , idx )
66
+ expr = expr_template .format (scalar_i , expected )
67
+ assert scalar_o == expected , (
68
+ f"{ f_o } ={ scalar_o } , but should be { expr } [{ func_name } ()]\n "
69
+ f"{ f_i } ={ scalar_i } "
70
+ )
71
+
72
+
73
+ def binary_assert_against_refimpl (
74
+ func_name : str ,
75
+ in_stype : ScalarType ,
76
+ left : Array ,
77
+ right : Array ,
78
+ res : Array ,
79
+ refimpl : Callable [[Scalar , Scalar ], Scalar ],
80
+ expr_template : str ,
81
+ res_stype : Optional [ScalarType ] = None ,
82
+ left_sym : str = "x1" ,
83
+ right_sym : str = "x2" ,
84
+ res_sym : str = "out" ,
85
+ ):
86
+ if res_stype is None :
87
+ res_stype = in_stype
88
+ for l_idx , r_idx , o_idx in sh .iter_indices (left .shape , right .shape , res .shape ):
89
+ scalar_l = in_stype (left [l_idx ])
90
+ scalar_r = in_stype (right [r_idx ])
91
+ expected = refimpl (scalar_l , scalar_r )
92
+ scalar_o = res_stype (res [o_idx ])
93
+ f_l = sh .fmt_idx (left_sym , l_idx )
94
+ f_r = sh .fmt_idx (right_sym , r_idx )
95
+ f_o = sh .fmt_idx (res_sym , o_idx )
96
+ expr = expr_template .format (scalar_l , scalar_r , expected )
97
+ assert scalar_o == expected , (
98
+ f"{ f_o } ={ scalar_o } , but should be { expr } [{ func_name } ()]\n "
99
+ f"{ f_l } ={ scalar_l } , { f_r } ={ scalar_r } "
100
+ )
101
+
102
+
47
103
# When appropiate, this module tests operators alongside their respective
48
104
# elementwise methods. We do this by parametrizing a generalised test method
49
105
# with every relevant method and operator.
@@ -1249,53 +1305,45 @@ def test_logical_and(x1, x2):
1249
1305
out = ah .logical_and (x1 , x2 )
1250
1306
ph .assert_dtype ("logical_and" , (x1 .dtype , x2 .dtype ), out .dtype )
1251
1307
ph .assert_result_shape ("logical_and" , (x1 .shape , x2 .shape ), out .shape )
1252
- for l_idx , r_idx , o_idx in sh .iter_indices (x1 .shape , x2 .shape , out .shape ):
1253
- scalar_l = bool (x1 [l_idx ])
1254
- scalar_r = bool (x2 [r_idx ])
1255
- expected = scalar_l and scalar_r
1256
- scalar_o = bool (out [o_idx ])
1257
- f_l = sh .fmt_idx ("x1" , l_idx )
1258
- f_r = sh .fmt_idx ("x2" , r_idx )
1259
- f_o = sh .fmt_idx ("out" , o_idx )
1260
- assert scalar_o == expected , (
1261
- f"{ f_o } ={ scalar_o } , but should be ({ f_l } and { f_r } )={ expected } "
1262
- f"[logical_and()]\n { f_l } ={ scalar_l } , { f_r } ={ scalar_r } "
1263
- )
1308
+ binary_assert_against_refimpl (
1309
+ "logical_and" ,
1310
+ bool ,
1311
+ x1 ,
1312
+ x2 ,
1313
+ out ,
1314
+ lambda l , r : l and r ,
1315
+ "({} and {})={}" ,
1316
+ )
1264
1317
1265
1318
1266
1319
@given (xps .arrays (dtype = xp .bool , shape = hh .shapes ()))
1267
1320
def test_logical_not (x ):
1268
1321
out = ah .logical_not (x )
1269
1322
ph .assert_dtype ("logical_not" , x .dtype , out .dtype )
1270
1323
ph .assert_shape ("logical_not" , out .shape , x .shape )
1271
- for idx in sh .ndindex (x .shape ):
1272
- assert out [idx ] == (not bool (x [idx ]))
1324
+ unary_assert_against_refimpl (
1325
+ "logical_not" , bool , x , out , lambda i : not i , "(not {})={}"
1326
+ )
1273
1327
1274
1328
1275
1329
@given (* hh .two_mutual_arrays ([xp .bool ]))
1276
1330
def test_logical_or (x1 , x2 ):
1277
1331
out = ah .logical_or (x1 , x2 )
1278
1332
ph .assert_dtype ("logical_or" , (x1 .dtype , x2 .dtype ), out .dtype )
1279
- # See the comments in test_equal
1280
- shape = sh .broadcast_shapes (x1 .shape , x2 .shape )
1281
- ph .assert_shape ("logical_or" , out .shape , shape )
1282
- _x1 = xp .broadcast_to (x1 , shape )
1283
- _x2 = xp .broadcast_to (x2 , shape )
1284
- for idx in sh .ndindex (shape ):
1285
- assert out [idx ] == (bool (_x1 [idx ]) or bool (_x2 [idx ]))
1333
+ ph .assert_result_shape ("logical_or" , (x1 .shape , x2 .shape ), out .shape )
1334
+ binary_assert_against_refimpl (
1335
+ "logical_or" , bool , x1 , x2 , out , lambda l , r : l or r , "({} or {})={}"
1336
+ )
1286
1337
1287
1338
1288
1339
@given (* hh .two_mutual_arrays ([xp .bool ]))
1289
1340
def test_logical_xor (x1 , x2 ):
1290
1341
out = xp .logical_xor (x1 , x2 )
1291
1342
ph .assert_dtype ("logical_xor" , (x1 .dtype , x2 .dtype ), out .dtype )
1292
- # See the comments in test_equal
1293
- shape = sh .broadcast_shapes (x1 .shape , x2 .shape )
1294
- ph .assert_shape ("logical_xor" , out .shape , shape )
1295
- _x1 = xp .broadcast_to (x1 , shape )
1296
- _x2 = xp .broadcast_to (x2 , shape )
1297
- for idx in sh .ndindex (shape ):
1298
- assert out [idx ] == (bool (_x1 [idx ]) ^ bool (_x2 [idx ]))
1343
+ ph .assert_result_shape ("logical_xor" , (x1 .shape , x2 .shape ), out .shape )
1344
+ binary_assert_against_refimpl (
1345
+ "logical_xor" , bool , x1 , x2 , out , lambda l , r : l ^ r , "({} ^ {})={}"
1346
+ )
1299
1347
1300
1348
1301
1349
@pytest .mark .parametrize ("ctx" , make_binary_params ("multiply" , xps .numeric_dtypes ()))
0 commit comments