@@ -184,6 +184,53 @@ def binary_assert_against_refimpl(
184
184
)
185
185
186
186
187
+ def right_scalar_assert_against_refimpl (
188
+ func_name : str ,
189
+ left : Array ,
190
+ right : Scalar ,
191
+ res : Array ,
192
+ refimpl : Callable [[T , T ], T ],
193
+ expr_template : str = None ,
194
+ res_stype : Optional [ScalarType ] = None ,
195
+ left_sym : str = "x1" ,
196
+ res_name : str = "out" ,
197
+ filter_ : Callable [[Scalar ], bool ] = default_filter ,
198
+ strict_check : Optional [bool ] = None ,
199
+ ):
200
+ if filter_ (right ):
201
+ return # short-circuit here as there will be nothing to test
202
+ in_stype = dh .get_scalar_type (left .dtype )
203
+ if res_stype is None :
204
+ res_stype = in_stype
205
+ m , M = dh .dtype_ranges .get (left .dtype , (None , None ))
206
+ for idx in sh .ndindex (res .shape ):
207
+ scalar_l = in_stype (left [idx ])
208
+ if not filter_ (scalar_l ):
209
+ continue
210
+ try :
211
+ expected = refimpl (scalar_l , right )
212
+ except Exception :
213
+ continue
214
+ if left .dtype != xp .bool :
215
+ assert m is not None and M is not None # for mypy
216
+ if expected <= m or expected >= M :
217
+ continue
218
+ scalar_o = res_stype (res [idx ])
219
+ f_l = sh .fmt_idx (left_sym , idx )
220
+ f_o = sh .fmt_idx (res_name , idx )
221
+ expr = expr_template .format (f_l , right , expected )
222
+ if strict_check == False or dh .is_float_dtype (res .dtype ):
223
+ assert isclose (scalar_o , expected ), (
224
+ f"{ f_o } ={ scalar_o } , but should be roughly { expr } [{ func_name } ()]\n "
225
+ f"{ f_l } ={ scalar_l } "
226
+ )
227
+ else :
228
+ assert scalar_o == expected , (
229
+ f"{ f_o } ={ scalar_o } , but should be { expr } [{ func_name } ()]\n "
230
+ f"{ f_l } ={ scalar_l } "
231
+ )
232
+
233
+
187
234
# When appropiate, this module tests operators alongside their respective
188
235
# elementwise methods. We do this by parametrizing a generalised test method
189
236
# with every relevant method and operator.
@@ -392,40 +439,19 @@ def binary_param_assert_against_refimpl(
392
439
):
393
440
expr_template = "({} " + op_sym + " {})={}"
394
441
if ctx .right_is_scalar :
395
- if filter_ (right ):
396
- return # short-circuit here as there will be nothing to test
397
- in_stype = dh .get_scalar_type (left .dtype )
398
- if res_stype is None :
399
- res_stype = in_stype
400
- m , M = dh .dtype_ranges .get (left .dtype , (None , None ))
401
- for idx in sh .ndindex (res .shape ):
402
- scalar_l = in_stype (left [idx ])
403
- if not filter_ (scalar_l ):
404
- continue
405
- try :
406
- expected = refimpl (scalar_l , right )
407
- except Exception :
408
- continue
409
- if left .dtype != xp .bool :
410
- assert m is not None and M is not None # for mypy
411
- if expected <= m or expected >= M :
412
- continue
413
- scalar_o = res_stype (res [idx ])
414
- f_l = sh .fmt_idx (ctx .left_sym , idx )
415
- f_o = sh .fmt_idx (ctx .res_name , idx )
416
- expr = expr_template .format (f_l , right , expected )
417
- if strict_check == False or dh .is_float_dtype (res .dtype ):
418
- assert isclose (scalar_o , expected ), (
419
- f"{ f_o } ={ scalar_o } , but should be roughly { expr } "
420
- f"[{ ctx .func_name } ()]\n "
421
- f"{ f_l } ={ scalar_l } "
422
- )
423
- else :
424
- assert scalar_o == expected , (
425
- f"{ f_o } ={ scalar_o } , but should be { expr } "
426
- f"[{ ctx .func_name } ()]\n "
427
- f"{ f_l } ={ scalar_l } "
428
- )
442
+ right_scalar_assert_against_refimpl (
443
+ func_name = ctx .func_name ,
444
+ left_sym = ctx .left_sym ,
445
+ left = left ,
446
+ right = right ,
447
+ res_stype = res_stype ,
448
+ res_name = ctx .res_name ,
449
+ res = res ,
450
+ refimpl = refimpl ,
451
+ expr_template = expr_template ,
452
+ filter_ = filter_ ,
453
+ strict_check = strict_check ,
454
+ )
429
455
else :
430
456
binary_assert_against_refimpl (
431
457
func_name = ctx .func_name ,
0 commit comments