Skip to content

Commit 3992a6e

Browse files
committed
Style changes
1 parent fb56b22 commit 3992a6e

File tree

1 file changed

+23
-21
lines changed

1 file changed

+23
-21
lines changed

array_api_tests/test_creation_functions.py

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,9 @@ def assert_shape(func_name: str, out_shape: Shape, expected: Union[int, Shape],
5151
assert out_shape == expected, msg
5252

5353

54-
5554
def assert_fill(func_name: str, fill: float, dtype: DataType, out: Array, **kw):
5655
f_kw = ", ".join(f"{k}={v}" for k, v in kw.items())
57-
msg = (
58-
f"out not filled with {fill} [{func_name}({f_kw})]\n"
59-
f"{out=}"
60-
)
56+
msg = f"out not filled with {fill} [{func_name}({f_kw})]\n" f"{out=}"
6157
if math.isnan(fill):
6258
assert ah.all(ah.isnan(out)), msg
6359
else:
@@ -96,7 +92,7 @@ def reals(min_value=None, max_value=None) -> st.SearchStrategy[Union[int, float]
9692
# in https://github.com/HypothesisWorks/hypothesis/issues/2907
9793
st.floats(min_value, max_value, allow_nan=False, allow_infinity=False).filter(
9894
lambda n: float_min <= n <= float_max
99-
)
95+
),
10096
)
10197

10298

@@ -118,9 +114,9 @@ def test_arange(start, dtype, data):
118114
step = data.draw(
119115
st.one_of(
120116
reals(min_value=tol).filter(lambda n: n != 0),
121-
reals(max_value=-tol).filter(lambda n: n != 0)
117+
reals(max_value=-tol).filter(lambda n: n != 0),
122118
),
123-
label="step"
119+
label="step",
124120
)
125121

126122
all_int = all(arg is None or isinstance(arg, int) for arg in [start, stop, step])
@@ -147,11 +143,15 @@ def test_arange(start, dtype, data):
147143
else:
148144
condition = lambda x: x >= _stop
149145
scalar_type = int if dh.is_int_dtype(_dtype) else float
150-
elements = list(scalar_type(n) for n in takewhile(condition, count(_start, step)))
146+
elements = list(
147+
scalar_type(n) for n in takewhile(condition, count(_start, step))
148+
)
151149
else:
152150
elements = []
153151
size = len(elements)
154-
assert size <= hh.MAX_ARRAY_SIZE, f"{size=}, should be no more than {hh.MAX_ARRAY_SIZE=}"
152+
assert (
153+
size <= hh.MAX_ARRAY_SIZE
154+
), f"{size=}, should be no more than {hh.MAX_ARRAY_SIZE=}"
155155

156156
out = xp.arange(start, stop=stop, step=step, dtype=dtype)
157157

@@ -178,7 +178,8 @@ def test_arange(start, dtype, data):
178178
if dh.is_int_dtype(_dtype):
179179
ah.assert_exactly_equal(out, ah.asarray(elements, dtype=_dtype))
180180
else:
181-
pass # TODO: either emulate array module behaviour or assert a rough equals
181+
pass # TODO: either emulate array module behaviour or assert a rough equals
182+
182183

183184
@given(hh.shapes(), hh.kwargs(dtype=st.none() | hh.shared_dtypes))
184185
def test_empty(shape, kw):
@@ -192,7 +193,7 @@ def test_empty(shape, kw):
192193

193194
@given(
194195
x=xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes()),
195-
kw=hh.kwargs(dtype=st.none() | xps.scalar_dtypes())
196+
kw=hh.kwargs(dtype=st.none() | xps.scalar_dtypes()),
196197
)
197198
def test_empty_like(x, kw):
198199
out = xp.empty_like(x, **kw)
@@ -209,7 +210,7 @@ def test_empty_like(x, kw):
209210
kw=hh.kwargs(
210211
k=st.integers(),
211212
dtype=xps.numeric_dtypes(),
212-
)
213+
),
213214
)
214215
def test_eye(n_rows, n_cols, kw):
215216
out = xp.eye(n_rows, n_cols, **kw)
@@ -237,10 +238,11 @@ def test_eye(n_rows, n_cols, kw):
237238
)
238239

239240

240-
241241
@st.composite
242242
def full_fill_values(draw) -> st.SearchStrategy[float]:
243-
kw = draw(st.shared(hh.kwargs(dtype=st.none() | xps.scalar_dtypes()), key="full_kw"))
243+
kw = draw(
244+
st.shared(hh.kwargs(dtype=st.none() | xps.scalar_dtypes()), key="full_kw")
245+
)
244246
dtype = kw.get("dtype", None) or draw(default_safe_dtypes)
245247
return draw(xps.from_dtype(dtype))
246248

@@ -262,7 +264,7 @@ def test_full(shape, fill_value, kw):
262264
dtype = dh.default_float
263265
if kw.get("dtype", None) is None:
264266
if isinstance(fill_value, bool):
265-
pass # TODO
267+
pass # TODO
266268
elif isinstance(fill_value, int):
267269
assert_default_int("full", out.dtype)
268270
else:
@@ -275,7 +277,9 @@ def test_full(shape, fill_value, kw):
275277

276278
@st.composite
277279
def full_like_fill_values(draw):
278-
kw = draw(st.shared(hh.kwargs(dtype=st.none() | xps.scalar_dtypes()), key="full_like_kw"))
280+
kw = draw(
281+
st.shared(hh.kwargs(dtype=st.none() | xps.scalar_dtypes()), key="full_like_kw")
282+
)
279283
dtype = kw.get("dtype", None) or draw(hh.shared_dtypes)
280284
return draw(xps.from_dtype(dtype))
281285

@@ -295,6 +299,7 @@ def test_full_like(x, fill_value, kw):
295299
assert_shape("full_like", out.shape, x.shape)
296300
assert_fill("full_like", fill_value, dtype, out, fill_value=fill_value)
297301

302+
298303
finite_kw = {"allow_nan": False, "allow_infinity": False}
299304

300305

@@ -303,10 +308,7 @@ def int_stops(draw, start: int, min_gap: int, m: int, M: int):
303308
sign = draw(st.booleans().map(int))
304309
max_gap = abs(M - m)
305310
max_int = math.floor(math.sqrt(max_gap))
306-
gap = draw(
307-
st.just(0),
308-
st.integers(1, max_int).map(lambda n: min_gap ** n)
309-
)
311+
gap = draw(st.just(0) | st.integers(1, max_int).map(lambda n: min_gap ** n))
310312
stop = start + sign * gap
311313
assume(m <= stop <= M)
312314
return stop

0 commit comments

Comments
 (0)