Skip to content

Commit aaf0a7d

Browse files
committed
Test more valid argument signatures in test_arange
1 parent f7fc94b commit aaf0a7d

File tree

2 files changed

+19
-14
lines changed

2 files changed

+19
-14
lines changed

array_api_tests/pytest_helpers.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
"raises",
1212
"doesnt_raise",
1313
"nargs",
14+
"fmt_kw",
1415
"assert_dtype",
1516
"assert_kw_dtype",
1617
"assert_default_float",
@@ -58,7 +59,7 @@ def nargs(func_name):
5859
return len(getfullargspec(getattr(function_stubs, func_name)).args)
5960

6061

61-
def _fmt_kw(kw: Dict[str, Any]) -> str:
62+
def fmt_kw(kw: Dict[str, Any]) -> str:
6263
return ", ".join(f"{k}={v}" for k, v in kw.items())
6364

6465

@@ -120,15 +121,15 @@ def assert_shape(
120121
if isinstance(expected, int):
121122
expected = (expected,)
122123
msg = (
123-
f"out.shape={out_shape}, but should be {expected} [{func_name}({_fmt_kw(kw)})]"
124+
f"out.shape={out_shape}, but should be {expected} [{func_name}({fmt_kw(kw)})]"
124125
)
125126
assert out_shape == expected, msg
126127

127128

128129
def assert_fill(
129130
func_name: str, fill_value: Scalar, dtype: DataType, out: Array, /, **kw
130131
):
131-
msg = f"out not filled with {fill_value} [{func_name}({_fmt_kw(kw)})]\n{out=}"
132+
msg = f"out not filled with {fill_value} [{func_name}({fmt_kw(kw)})]\n{out=}"
132133
if math.isnan(fill_value):
133134
assert ah.all(ah.isnan(out)), msg
134135
else:

array_api_tests/test_creation_functions.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -131,15 +131,16 @@ def test_arange(dtype, data):
131131
size <= hh.MAX_ARRAY_SIZE
132132
), f"{size=} should be no more than {hh.MAX_ARRAY_SIZE}" # sanity check
133133

134-
kw = data.draw(
135-
hh.specified_kwargs(
136-
hh.KVD("stop", stop, None),
137-
hh.KVD("step", step, None),
138-
hh.KVD("dtype", dtype, None),
139-
),
140-
label="kw",
141-
)
142-
out = xp.arange(start, **kw)
134+
args_samples = [(start, stop), (start, stop, step)]
135+
if stop is None:
136+
args_samples.insert(0, (start,))
137+
args = data.draw(st.sampled_from(args_samples), label="args")
138+
kvds = [hh.KVD("dtype", dtype, None)]
139+
if len(args) != 3:
140+
kvds.insert(0, hh.KVD("step", step, 1))
141+
kwargs = data.draw(hh.specified_kwargs(*kvds), label="kwargs")
142+
143+
out = xp.arange(*args, **kwargs)
143144

144145
if dtype is None:
145146
if all_int:
@@ -148,8 +149,11 @@ def test_arange(dtype, data):
148149
ph.assert_default_float("arange", out.dtype)
149150
else:
150151
ph.assert_dtype("arange", (out.dtype,), dtype)
151-
assert out.ndim == 1, f"{out.ndim=}, but should be 1 [linspace()]"
152-
f_func = f"[arange({start}, {stop}, {step})]"
152+
f_sig = ", ".join(str(n) for n in args)
153+
if len(kwargs) > 0:
154+
f_sig += f", {ph.fmt_kw(kwargs)}"
155+
f_func = f"[arange({f_sig})]"
156+
assert out.ndim == 1, f"{out.ndim=}, but should be 1 [{f_func}]"
153157
# We check size is roughly as expected to avoid edge cases e.g.
154158
#
155159
# >>> xp.arange(2, step=0.333333333333333)

0 commit comments

Comments
 (0)