Skip to content

Commit 4574441

Browse files
committed
Only generate 1D arrays in test_meshgrid, prevent memory errors
1 parent aaf0a7d commit 4574441

File tree

1 file changed

+82
-75
lines changed

1 file changed

+82
-75
lines changed

array_api_tests/test_type_promotion.py

Lines changed: 82 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -24,22 +24,29 @@
2424
@given(hh.mutually_promotable_dtypes(None))
2525
def test_result_type(dtypes):
2626
out = xp.result_type(*dtypes)
27-
ph.assert_dtype('result_type', dtypes, out, out_name='out')
27+
ph.assert_dtype("result_type", dtypes, out, out_name="out")
2828

2929

30+
# The number and size of generated arrays is arbitrarily limited to prevent
31+
# meshgrid() running out of memory.
3032
@given(
31-
dtypes=hh.mutually_promotable_dtypes(None, dtypes=dh.numeric_dtypes),
33+
dtypes=hh.mutually_promotable_dtypes(5, dtypes=dh.numeric_dtypes),
3234
data=st.data(),
3335
)
3436
def test_meshgrid(dtypes, data):
3537
arrays = []
36-
shapes = data.draw(hh.mutually_broadcastable_shapes(len(dtypes)), label='shapes')
38+
shapes = data.draw(
39+
hh.mutually_broadcastable_shapes(
40+
len(dtypes), min_dims=1, max_dims=1, max_side=5
41+
),
42+
label="shapes",
43+
)
3744
for i, (dtype, shape) in enumerate(zip(dtypes, shapes), 1):
38-
x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f'x{i}')
45+
x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f"x{i}")
3946
arrays.append(x)
4047
out = xp.meshgrid(*arrays)
4148
for i, x in enumerate(out):
42-
ph.assert_dtype('meshgrid', dtypes, x.dtype, out_name=f'out[{i}].dtype')
49+
ph.assert_dtype("meshgrid", dtypes, x.dtype, out_name=f"out[{i}].dtype")
4350

4451

4552
@given(
@@ -50,10 +57,10 @@ def test_meshgrid(dtypes, data):
5057
def test_concat(shape, dtypes, data):
5158
arrays = []
5259
for i, dtype in enumerate(dtypes, 1):
53-
x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f'x{i}')
60+
x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f"x{i}")
5461
arrays.append(x)
5562
out = xp.concat(arrays)
56-
ph.assert_dtype('concat', dtypes, out.dtype)
63+
ph.assert_dtype("concat", dtypes, out.dtype)
5764

5865

5966
@given(
@@ -64,26 +71,26 @@ def test_concat(shape, dtypes, data):
6471
def test_stack(shape, dtypes, data):
6572
arrays = []
6673
for i, dtype in enumerate(dtypes, 1):
67-
x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f'x{i}')
74+
x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f"x{i}")
6875
arrays.append(x)
6976
out = xp.stack(arrays)
70-
ph.assert_dtype('stack', dtypes, out.dtype)
77+
ph.assert_dtype("stack", dtypes, out.dtype)
7178

7279

7380
bitwise_shift_funcs = [
74-
'bitwise_left_shift',
75-
'bitwise_right_shift',
76-
'__lshift__',
77-
'__rshift__',
78-
'__ilshift__',
79-
'__irshift__',
81+
"bitwise_left_shift",
82+
"bitwise_right_shift",
83+
"__lshift__",
84+
"__rshift__",
85+
"__ilshift__",
86+
"__irshift__",
8087
]
8188

8289

8390
# We pass kwargs to the elements strategy used by xps.arrays() so that we don't
8491
# generate array elements that are erroneous or undefined for a function.
8592
func_elements = defaultdict(
86-
lambda: None, {func: {'min_value': 1} for func in bitwise_shift_funcs}
93+
lambda: None, {func: {"min_value": 1} for func in bitwise_shift_funcs}
8794
)
8895

8996

@@ -94,7 +101,7 @@ def make_id(
94101
) -> str:
95102
f_args = dh.fmt_types(in_dtypes)
96103
f_out_dtype = dh.dtype_to_name[out_dtype]
97-
return f'{func_name}({f_args}) -> {f_out_dtype}'
104+
return f"{func_name}({f_args}) -> {f_out_dtype}"
98105

99106

100107
func_params: List[Param[str, Tuple[DataType, ...], DataType]] = []
@@ -128,25 +135,25 @@ def make_id(
128135
raise NotImplementedError()
129136

130137

131-
@pytest.mark.parametrize('func_name, in_dtypes, out_dtype', func_params)
138+
@pytest.mark.parametrize("func_name, in_dtypes, out_dtype", func_params)
132139
@given(data=st.data())
133140
def test_func_promotion(func_name, in_dtypes, out_dtype, data):
134141
func = getattr(xp, func_name)
135142
elements = func_elements[func_name]
136143
if len(in_dtypes) == 1:
137144
x = data.draw(
138145
xps.arrays(dtype=in_dtypes[0], shape=hh.shapes(), elements=elements),
139-
label='x',
146+
label="x",
140147
)
141148
out = func(x)
142149
else:
143150
arrays = []
144151
shapes = data.draw(
145-
hh.mutually_broadcastable_shapes(len(in_dtypes)), label='shapes'
152+
hh.mutually_broadcastable_shapes(len(in_dtypes)), label="shapes"
146153
)
147154
for i, (dtype, shape) in enumerate(zip(in_dtypes, shapes), 1):
148155
x = data.draw(
149-
xps.arrays(dtype=dtype, shape=shape, elements=elements), label=f'x{i}'
156+
xps.arrays(dtype=dtype, shape=shape, elements=elements), label=f"x{i}"
150157
)
151158
arrays.append(x)
152159
try:
@@ -161,46 +168,46 @@ def test_func_promotion(func_name, in_dtypes, out_dtype, data):
161168
p = pytest.param(
162169
(dtype1, dtype2),
163170
promoted_dtype,
164-
id=make_id('', (dtype1, dtype2), promoted_dtype),
171+
id=make_id("", (dtype1, dtype2), promoted_dtype),
165172
)
166173
promotion_params.append(p)
167174

168175

169-
@pytest.mark.parametrize('in_dtypes, out_dtype', promotion_params)
176+
@pytest.mark.parametrize("in_dtypes, out_dtype", promotion_params)
170177
@given(shapes=hh.mutually_broadcastable_shapes(3), data=st.data())
171178
def test_where(in_dtypes, out_dtype, shapes, data):
172-
x1 = data.draw(xps.arrays(dtype=in_dtypes[0], shape=shapes[0]), label='x1')
173-
x2 = data.draw(xps.arrays(dtype=in_dtypes[1], shape=shapes[1]), label='x2')
174-
cond = data.draw(xps.arrays(dtype=xp.bool, shape=shapes[2]), label='condition')
179+
x1 = data.draw(xps.arrays(dtype=in_dtypes[0], shape=shapes[0]), label="x1")
180+
x2 = data.draw(xps.arrays(dtype=in_dtypes[1], shape=shapes[1]), label="x2")
181+
cond = data.draw(xps.arrays(dtype=xp.bool, shape=shapes[2]), label="condition")
175182
out = xp.where(cond, x1, x2)
176-
ph.assert_dtype('where', in_dtypes, out.dtype, out_dtype)
183+
ph.assert_dtype("where", in_dtypes, out.dtype, out_dtype)
177184

178185

179186
numeric_promotion_params = promotion_params[1:]
180187

181188

182-
@pytest.mark.parametrize('in_dtypes, out_dtype', numeric_promotion_params)
189+
@pytest.mark.parametrize("in_dtypes, out_dtype", numeric_promotion_params)
183190
@given(shapes=hh.mutually_broadcastable_shapes(2, min_dims=2), data=st.data())
184191
def test_tensordot(in_dtypes, out_dtype, shapes, data):
185-
x1 = data.draw(xps.arrays(dtype=in_dtypes[0], shape=shapes[0]), label='x1')
186-
x2 = data.draw(xps.arrays(dtype=in_dtypes[1], shape=shapes[1]), label='x2')
192+
x1 = data.draw(xps.arrays(dtype=in_dtypes[0], shape=shapes[0]), label="x1")
193+
x2 = data.draw(xps.arrays(dtype=in_dtypes[1], shape=shapes[1]), label="x2")
187194
out = xp.tensordot(x1, x2)
188-
ph.assert_dtype('tensordot', in_dtypes, out.dtype, out_dtype)
195+
ph.assert_dtype("tensordot", in_dtypes, out.dtype, out_dtype)
189196

190197

191-
@pytest.mark.parametrize('in_dtypes, out_dtype', numeric_promotion_params)
198+
@pytest.mark.parametrize("in_dtypes, out_dtype", numeric_promotion_params)
192199
@given(shapes=hh.mutually_broadcastable_shapes(2, min_dims=1), data=st.data())
193200
def test_vecdot(in_dtypes, out_dtype, shapes, data):
194-
x1 = data.draw(xps.arrays(dtype=in_dtypes[0], shape=shapes[0]), label='x1')
195-
x2 = data.draw(xps.arrays(dtype=in_dtypes[1], shape=shapes[1]), label='x2')
201+
x1 = data.draw(xps.arrays(dtype=in_dtypes[0], shape=shapes[0]), label="x1")
202+
x2 = data.draw(xps.arrays(dtype=in_dtypes[1], shape=shapes[1]), label="x2")
196203
out = xp.vecdot(x1, x2)
197-
ph.assert_dtype('vecdot', in_dtypes, out.dtype, out_dtype)
204+
ph.assert_dtype("vecdot", in_dtypes, out.dtype, out_dtype)
198205

199206

200207
op_params: List[Param[str, str, Tuple[DataType, ...], DataType]] = []
201208
op_to_symbol = {**dh.unary_op_to_symbol, **dh.binary_op_to_symbol}
202209
for op, symbol in op_to_symbol.items():
203-
if op == '__matmul__':
210+
if op == "__matmul__":
204211
continue
205212
valid_in_dtypes = dh.func_in_dtypes[op]
206213
ndtypes = ph.nargs(op)
@@ -209,7 +216,7 @@ def test_vecdot(in_dtypes, out_dtype, shapes, data):
209216
out_dtype = xp.bool if dh.func_returns_bool[op] else in_dtype
210217
p = pytest.param(
211218
op,
212-
f'{symbol}x',
219+
f"{symbol}x",
213220
(in_dtype,),
214221
out_dtype,
215222
id=make_id(op, (in_dtype,), out_dtype),
@@ -221,42 +228,42 @@ def test_vecdot(in_dtypes, out_dtype, shapes, data):
221228
out_dtype = xp.bool if dh.func_returns_bool[op] else promoted_dtype
222229
p = pytest.param(
223230
op,
224-
f'x1 {symbol} x2',
231+
f"x1 {symbol} x2",
225232
(in_dtype1, in_dtype2),
226233
out_dtype,
227234
id=make_id(op, (in_dtype1, in_dtype2), out_dtype),
228235
)
229236
op_params.append(p)
230237
# We generate params for abs seperately as it does not have an associated symbol
231-
for in_dtype in dh.func_in_dtypes['__abs__']:
238+
for in_dtype in dh.func_in_dtypes["__abs__"]:
232239
p = pytest.param(
233-
'__abs__',
234-
'abs(x)',
240+
"__abs__",
241+
"abs(x)",
235242
(in_dtype,),
236243
in_dtype,
237-
id=make_id('__abs__', (in_dtype,), in_dtype),
244+
id=make_id("__abs__", (in_dtype,), in_dtype),
238245
)
239246
op_params.append(p)
240247

241248

242-
@pytest.mark.parametrize('op, expr, in_dtypes, out_dtype', op_params)
249+
@pytest.mark.parametrize("op, expr, in_dtypes, out_dtype", op_params)
243250
@given(data=st.data())
244251
def test_op_promotion(op, expr, in_dtypes, out_dtype, data):
245252
elements = func_elements[func_name]
246253
if len(in_dtypes) == 1:
247254
x = data.draw(
248255
xps.arrays(dtype=in_dtypes[0], shape=hh.shapes(), elements=elements),
249-
label='x',
256+
label="x",
250257
)
251-
out = eval(expr, {'x': x})
258+
out = eval(expr, {"x": x})
252259
else:
253260
locals_ = {}
254261
shapes = data.draw(
255-
hh.mutually_broadcastable_shapes(len(in_dtypes)), label='shapes'
262+
hh.mutually_broadcastable_shapes(len(in_dtypes)), label="shapes"
256263
)
257264
for i, (dtype, shape) in enumerate(zip(in_dtypes, shapes), 1):
258-
locals_[f'x{i}'] = data.draw(
259-
xps.arrays(dtype=dtype, shape=shape, elements=elements), label=f'x{i}'
265+
locals_[f"x{i}"] = data.draw(
266+
xps.arrays(dtype=dtype, shape=shape, elements=elements), label=f"x{i}"
260267
)
261268
try:
262269
out = eval(expr, locals_)
@@ -267,7 +274,7 @@ def test_op_promotion(op, expr, in_dtypes, out_dtype, data):
267274

268275
inplace_params: List[Param[str, str, Tuple[DataType, ...], DataType]] = []
269276
for op, symbol in dh.inplace_op_to_symbol.items():
270-
if op == '__imatmul__':
277+
if op == "__imatmul__":
271278
continue
272279
valid_in_dtypes = dh.func_in_dtypes[op]
273280
for (in_dtype1, in_dtype2), promoted_dtype in dh.promotion_table.items():
@@ -278,44 +285,44 @@ def test_op_promotion(op, expr, in_dtypes, out_dtype, data):
278285
):
279286
p = pytest.param(
280287
op,
281-
f'x1 {symbol} x2',
288+
f"x1 {symbol} x2",
282289
(in_dtype1, in_dtype2),
283290
promoted_dtype,
284291
id=make_id(op, (in_dtype1, in_dtype2), promoted_dtype),
285292
)
286293
inplace_params.append(p)
287294

288295

289-
@pytest.mark.parametrize('op, expr, in_dtypes, out_dtype', inplace_params)
296+
@pytest.mark.parametrize("op, expr, in_dtypes, out_dtype", inplace_params)
290297
@given(shapes=hh.mutually_broadcastable_shapes(2), data=st.data())
291298
def test_inplace_op_promotion(op, expr, in_dtypes, out_dtype, shapes, data):
292299
assume(len(shapes[0]) >= len(shapes[1]))
293300
elements = func_elements[func_name]
294301
x1 = data.draw(
295-
xps.arrays(dtype=in_dtypes[0], shape=shapes[0], elements=elements), label='x1'
302+
xps.arrays(dtype=in_dtypes[0], shape=shapes[0], elements=elements), label="x1"
296303
)
297304
x2 = data.draw(
298-
xps.arrays(dtype=in_dtypes[1], shape=shapes[1], elements=elements), label='x2'
305+
xps.arrays(dtype=in_dtypes[1], shape=shapes[1], elements=elements), label="x2"
299306
)
300-
locals_ = {'x1': x1, 'x2': x2}
307+
locals_ = {"x1": x1, "x2": x2}
301308
try:
302309
exec(expr, locals_)
303310
except OverflowError:
304311
reject()
305-
x1 = locals_['x1']
306-
ph.assert_dtype(op, in_dtypes, x1.dtype, out_dtype, out_name='x1.dtype')
312+
x1 = locals_["x1"]
313+
ph.assert_dtype(op, in_dtypes, x1.dtype, out_dtype, out_name="x1.dtype")
307314

308315

309316
op_scalar_params: List[Param[str, str, DataType, ScalarType, DataType]] = []
310317
for op, symbol in dh.binary_op_to_symbol.items():
311-
if op == '__matmul__':
318+
if op == "__matmul__":
312319
continue
313320
for in_dtype in dh.func_in_dtypes[op]:
314321
out_dtype = xp.bool if dh.func_returns_bool[op] else in_dtype
315322
for in_stype in dh.dtype_to_scalars[in_dtype]:
316323
p = pytest.param(
317324
op,
318-
f'x {symbol} s',
325+
f"x {symbol} s",
319326
in_dtype,
320327
in_stype,
321328
out_dtype,
@@ -324,57 +331,57 @@ def test_inplace_op_promotion(op, expr, in_dtypes, out_dtype, shapes, data):
324331
op_scalar_params.append(p)
325332

326333

327-
@pytest.mark.parametrize('op, expr, in_dtype, in_stype, out_dtype', op_scalar_params)
334+
@pytest.mark.parametrize("op, expr, in_dtype, in_stype, out_dtype", op_scalar_params)
328335
@given(data=st.data())
329336
def test_op_scalar_promotion(op, expr, in_dtype, in_stype, out_dtype, data):
330337
elements = func_elements[func_name]
331-
kw = {k: in_stype is float for k in ('allow_nan', 'allow_infinity')}
332-
s = data.draw(xps.from_dtype(in_dtype, **kw).map(in_stype), label='scalar')
338+
kw = {k: in_stype is float for k in ("allow_nan", "allow_infinity")}
339+
s = data.draw(xps.from_dtype(in_dtype, **kw).map(in_stype), label="scalar")
333340
x = data.draw(
334-
xps.arrays(dtype=in_dtype, shape=hh.shapes(), elements=elements), label='x'
341+
xps.arrays(dtype=in_dtype, shape=hh.shapes(), elements=elements), label="x"
335342
)
336343
try:
337-
out = eval(expr, {'x': x, 's': s})
344+
out = eval(expr, {"x": x, "s": s})
338345
except OverflowError:
339346
reject()
340347
ph.assert_dtype(op, (in_dtype, in_stype), out.dtype, out_dtype)
341348

342349

343350
inplace_scalar_params: List[Param[str, str, DataType, ScalarType]] = []
344351
for op, symbol in dh.inplace_op_to_symbol.items():
345-
if op == '__imatmul__':
352+
if op == "__imatmul__":
346353
continue
347354
for dtype in dh.func_in_dtypes[op]:
348355
for in_stype in dh.dtype_to_scalars[dtype]:
349356
p = pytest.param(
350357
op,
351-
f'x {symbol} s',
358+
f"x {symbol} s",
352359
dtype,
353360
in_stype,
354361
id=make_id(op, (dtype, in_stype), dtype),
355362
)
356363
inplace_scalar_params.append(p)
357364

358365

359-
@pytest.mark.parametrize('op, expr, dtype, in_stype', inplace_scalar_params)
366+
@pytest.mark.parametrize("op, expr, dtype, in_stype", inplace_scalar_params)
360367
@given(data=st.data())
361368
def test_inplace_op_scalar_promotion(op, expr, dtype, in_stype, data):
362369
elements = func_elements[func_name]
363-
kw = {k: in_stype is float for k in ('allow_nan', 'allow_infinity')}
364-
s = data.draw(xps.from_dtype(dtype, **kw).map(in_stype), label='scalar')
370+
kw = {k: in_stype is float for k in ("allow_nan", "allow_infinity")}
371+
s = data.draw(xps.from_dtype(dtype, **kw).map(in_stype), label="scalar")
365372
x = data.draw(
366-
xps.arrays(dtype=dtype, shape=hh.shapes(), elements=elements), label='x'
373+
xps.arrays(dtype=dtype, shape=hh.shapes(), elements=elements), label="x"
367374
)
368-
locals_ = {'x': x, 's': s}
375+
locals_ = {"x": x, "s": s}
369376
try:
370377
exec(expr, locals_)
371378
except OverflowError:
372379
reject()
373-
x = locals_['x']
374-
assert x.dtype == dtype, f'{x.dtype=!s}, but should be {dtype}'
375-
ph.assert_dtype(op, (dtype, in_stype), x.dtype, dtype, out_name='x.dtype')
380+
x = locals_["x"]
381+
assert x.dtype == dtype, f"{x.dtype=!s}, but should be {dtype}"
382+
ph.assert_dtype(op, (dtype, in_stype), x.dtype, dtype, out_name="x.dtype")
376383

377384

378-
if __name__ == '__main__':
385+
if __name__ == "__main__":
379386
for (i, j), p in dh.promotion_table.items():
380-
print(f'({i}, {j}) -> {p}')
387+
print(f"({i}, {j}) -> {p}")

0 commit comments

Comments
 (0)