Skip to content

Commit d924ce4

Browse files
committed
Introduce right_scalar_assert_against_refimpl()
Keeps all refimpl logic near eachother
1 parent 493f669 commit d924ce4

File tree

1 file changed

+60
-34
lines changed

1 file changed

+60
-34
lines changed

array_api_tests/test_operators_and_elementwise_functions.py

Lines changed: 60 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,53 @@ def binary_assert_against_refimpl(
184184
)
185185

186186

187+
def right_scalar_assert_against_refimpl(
188+
func_name: str,
189+
left: Array,
190+
right: Scalar,
191+
res: Array,
192+
refimpl: Callable[[T, T], T],
193+
expr_template: str = None,
194+
res_stype: Optional[ScalarType] = None,
195+
left_sym: str = "x1",
196+
res_name: str = "out",
197+
filter_: Callable[[Scalar], bool] = default_filter,
198+
strict_check: Optional[bool] = None,
199+
):
200+
if filter_(right):
201+
return # short-circuit here as there will be nothing to test
202+
in_stype = dh.get_scalar_type(left.dtype)
203+
if res_stype is None:
204+
res_stype = in_stype
205+
m, M = dh.dtype_ranges.get(left.dtype, (None, None))
206+
for idx in sh.ndindex(res.shape):
207+
scalar_l = in_stype(left[idx])
208+
if not filter_(scalar_l):
209+
continue
210+
try:
211+
expected = refimpl(scalar_l, right)
212+
except Exception:
213+
continue
214+
if left.dtype != xp.bool:
215+
assert m is not None and M is not None # for mypy
216+
if expected <= m or expected >= M:
217+
continue
218+
scalar_o = res_stype(res[idx])
219+
f_l = sh.fmt_idx(left_sym, idx)
220+
f_o = sh.fmt_idx(res_name, idx)
221+
expr = expr_template.format(f_l, right, expected)
222+
if strict_check == False or dh.is_float_dtype(res.dtype):
223+
assert isclose(scalar_o, expected), (
224+
f"{f_o}={scalar_o}, but should be roughly {expr} [{func_name}()]\n"
225+
f"{f_l}={scalar_l}"
226+
)
227+
else:
228+
assert scalar_o == expected, (
229+
f"{f_o}={scalar_o}, but should be {expr} [{func_name}()]\n"
230+
f"{f_l}={scalar_l}"
231+
)
232+
233+
187234
# When appropiate, this module tests operators alongside their respective
188235
# elementwise methods. We do this by parametrizing a generalised test method
189236
# with every relevant method and operator.
@@ -392,40 +439,19 @@ def binary_param_assert_against_refimpl(
392439
):
393440
expr_template = "({} " + op_sym + " {})={}"
394441
if ctx.right_is_scalar:
395-
if filter_(right):
396-
return # short-circuit here as there will be nothing to test
397-
in_stype = dh.get_scalar_type(left.dtype)
398-
if res_stype is None:
399-
res_stype = in_stype
400-
m, M = dh.dtype_ranges.get(left.dtype, (None, None))
401-
for idx in sh.ndindex(res.shape):
402-
scalar_l = in_stype(left[idx])
403-
if not filter_(scalar_l):
404-
continue
405-
try:
406-
expected = refimpl(scalar_l, right)
407-
except Exception:
408-
continue
409-
if left.dtype != xp.bool:
410-
assert m is not None and M is not None # for mypy
411-
if expected <= m or expected >= M:
412-
continue
413-
scalar_o = res_stype(res[idx])
414-
f_l = sh.fmt_idx(ctx.left_sym, idx)
415-
f_o = sh.fmt_idx(ctx.res_name, idx)
416-
expr = expr_template.format(f_l, right, expected)
417-
if strict_check == False or dh.is_float_dtype(res.dtype):
418-
assert isclose(scalar_o, expected), (
419-
f"{f_o}={scalar_o}, but should be roughly {expr} "
420-
f"[{ctx.func_name}()]\n"
421-
f"{f_l}={scalar_l}"
422-
)
423-
else:
424-
assert scalar_o == expected, (
425-
f"{f_o}={scalar_o}, but should be {expr} "
426-
f"[{ctx.func_name}()]\n"
427-
f"{f_l}={scalar_l}"
428-
)
442+
right_scalar_assert_against_refimpl(
443+
func_name=ctx.func_name,
444+
left_sym=ctx.left_sym,
445+
left=left,
446+
right=right,
447+
res_stype=res_stype,
448+
res_name=ctx.res_name,
449+
res=res,
450+
refimpl=refimpl,
451+
expr_template=expr_template,
452+
filter_=filter_,
453+
strict_check=strict_check,
454+
)
429455
else:
430456
binary_assert_against_refimpl(
431457
func_name=ctx.func_name,

0 commit comments

Comments
 (0)