Skip to content

Commit b6d05da

Browse files
committed
Favour lists for ph.assert_result_shape()
1 parent 66a1fd4 commit b6d05da

File tree

3 files changed

+12
-12
lines changed

3 files changed

+12
-12
lines changed

array_api_tests/pytest_helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ def assert_shape(
149149

150150
def assert_result_shape(
151151
func_name: str,
152-
in_shapes: Tuple[Shape],
152+
in_shapes: Sequence[Shape],
153153
out_shape: Shape,
154154
/,
155155
expected: Optional[Shape] = None,

array_api_tests/test_manipulation_functions.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def test_expand_dims(x, axis):
142142
index = axis if axis >= 0 else x.ndim + axis + 1
143143
shape.insert(index, 1)
144144
shape = tuple(shape)
145-
ph.assert_result_shape("expand_dims", (x.shape,), out.shape, shape)
145+
ph.assert_result_shape("expand_dims", [x.shape], out.shape, shape)
146146

147147
assert_array_ndindex(
148148
"expand_dims", x, sh.ndindex(x.shape), out, sh.ndindex(out.shape)
@@ -181,7 +181,7 @@ def test_squeeze(x, data):
181181
if i not in axes:
182182
shape.append(side)
183183
shape = tuple(shape)
184-
ph.assert_result_shape("squeeze", (x.shape,), out.shape, shape, axis=axis)
184+
ph.assert_result_shape("squeeze", [x.shape], out.shape, shape, axis=axis)
185185

186186
assert_array_ndindex("squeeze", x, sh.ndindex(x.shape), out, sh.ndindex(out.shape))
187187

@@ -230,7 +230,7 @@ def test_permute_dims(x, axes):
230230
side = x.shape[dim]
231231
shape[i] = side
232232
shape = tuple(shape)
233-
ph.assert_result_shape("permute_dims", (x.shape,), out.shape, shape, axes=axes)
233+
ph.assert_result_shape("permute_dims", [x.shape], out.shape, shape, axes=axes)
234234

235235
indices = list(sh.ndindex(x.shape))
236236
permuted_indices = [tuple(idx[axis] for axis in axes) for idx in indices]
@@ -265,7 +265,7 @@ def test_reshape(x, data):
265265
rsize = math.prod(shape) * -1
266266
_shape[shape.index(-1)] = size / rsize
267267
_shape = tuple(_shape)
268-
ph.assert_result_shape("reshape", (x.shape,), out.shape, _shape, shape=shape)
268+
ph.assert_result_shape("reshape", [x.shape], out.shape, _shape, shape=shape)
269269

270270
assert_array_ndindex("reshape", x, sh.ndindex(x.shape), out, sh.ndindex(out.shape))
271271

@@ -303,7 +303,7 @@ def test_roll(x, data):
303303

304304
ph.assert_dtype("roll", x.dtype, out.dtype)
305305

306-
ph.assert_result_shape("roll", (x.shape,), out.shape)
306+
ph.assert_result_shape("roll", [x.shape], out.shape)
307307

308308
if kw.get("axis", None) is None:
309309
assert isinstance(shift, int) # sanity check

array_api_tests/test_operators_and_elementwise_functions.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -287,9 +287,9 @@ def assert_binary_param_shape(
287287
expected: Optional[Shape] = None,
288288
):
289289
if ctx.right_is_scalar:
290-
in_shapes = (left.shape,)
290+
in_shapes = [left.shape]
291291
else:
292-
in_shapes = (left.shape, right.shape) # type: ignore
292+
in_shapes = [left.shape, right.shape] # type: ignore
293293
ph.assert_result_shape(
294294
ctx.func_name, in_shapes, res.shape, expected, repr_name=f"{ctx.res_name}.shape"
295295
)
@@ -444,7 +444,7 @@ def test_atan(x):
444444
def test_atan2(x1, x2):
445445
out = xp.atan2(x1, x2)
446446
ph.assert_dtype("atan2", [x1.dtype, x2.dtype], out.dtype)
447-
ph.assert_result_shape("atan2", (x1.shape, x2.shape), out.shape)
447+
ph.assert_result_shape("atan2", [x1.shape, x2.shape], out.shape)
448448
INFINITY1 = ah.infinity(x1.shape, x1.dtype)
449449
INFINITY2 = ah.infinity(x2.shape, x2.dtype)
450450
PI = ah.π(out.shape, out.dtype)
@@ -1304,7 +1304,7 @@ def test_logaddexp(x1, x2):
13041304
def test_logical_and(x1, x2):
13051305
out = ah.logical_and(x1, x2)
13061306
ph.assert_dtype("logical_and", [x1.dtype, x2.dtype], out.dtype)
1307-
ph.assert_result_shape("logical_and", (x1.shape, x2.shape), out.shape)
1307+
ph.assert_result_shape("logical_and", [x1.shape, x2.shape], out.shape)
13081308
binary_assert_against_refimpl(
13091309
"logical_and",
13101310
bool,
@@ -1330,7 +1330,7 @@ def test_logical_not(x):
13301330
def test_logical_or(x1, x2):
13311331
out = ah.logical_or(x1, x2)
13321332
ph.assert_dtype("logical_or", [x1.dtype, x2.dtype], out.dtype)
1333-
ph.assert_result_shape("logical_or", (x1.shape, x2.shape), out.shape)
1333+
ph.assert_result_shape("logical_or", [x1.shape, x2.shape], out.shape)
13341334
binary_assert_against_refimpl(
13351335
"logical_or", bool, x1, x2, out, lambda l, r: l or r, "({} or {})={}"
13361336
)
@@ -1340,7 +1340,7 @@ def test_logical_or(x1, x2):
13401340
def test_logical_xor(x1, x2):
13411341
out = xp.logical_xor(x1, x2)
13421342
ph.assert_dtype("logical_xor", [x1.dtype, x2.dtype], out.dtype)
1343-
ph.assert_result_shape("logical_xor", (x1.shape, x2.shape), out.shape)
1343+
ph.assert_result_shape("logical_xor", [x1.shape, x2.shape], out.shape)
13441344
binary_assert_against_refimpl(
13451345
"logical_xor", bool, x1, x2, out, lambda l, r: l ^ r, "({} ^ {})={}"
13461346
)

0 commit comments

Comments
 (0)