Skip to content

Commit 8f95986

Browse files
committed
Fixed test_full_like generation to match spec
1 parent 4c68483 commit 8f95986

File tree

1 file changed

+20
-14
lines changed

1 file changed

+20
-14
lines changed

array_api_tests/test_creation_functions.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
scalars, xps, shared_optional_promotable_dtypes)
99

1010
from hypothesis import assume, given
11-
from hypothesis.strategies import integers, floats, one_of, none, booleans, just
11+
from hypothesis.strategies import integers, floats, one_of, none, booleans, just, shared, composite
1212

1313

1414

@@ -152,30 +152,36 @@ def test_full(shape, fill_value, dtype):
152152
else:
153153
assert all(equal(a, asarray(fill_value, **kwargs))), "full() array did not equal the fill value"
154154

155+
shared_optional_dtypes = shared(none() | shared_dtypes, key="optional_dtype")
156+
157+
@composite
158+
def fill_value(draw):
159+
dtype = draw(shared_optional_dtypes)
160+
if dtype is None:
161+
dtype = draw(shared_dtypes)
162+
return draw(xps.from_dtype(dtype))
163+
155164
@given(
156-
a=xps.arrays(
157-
dtype=shared_dtypes,
158-
shape=shapes,
159-
),
160-
fill_value=promotable_dtypes(shared_dtypes).flatmap(xps.from_dtype),
161-
dtype=shared_optional_promotable_dtypes,
165+
x=xps.arrays(dtype=shared_dtypes, shape=shapes),
166+
fill_value=fill_value(),
167+
dtype=shared_optional_dtypes,
162168
)
163-
def test_full_like(a, fill_value, dtype):
169+
def test_full_like(x, fill_value, dtype):
164170
kwargs = {} if dtype is None else {'dtype': dtype}
165171

166-
a_like = full_like(a, fill_value, **kwargs)
172+
x_like = full_like(x, fill_value, **kwargs)
167173

168174
if dtype is None:
169175
# TODO: Should it actually match a.dtype?
170176
pass
171177
else:
172-
assert a_like.dtype == dtype
178+
assert x_like.dtype == dtype
173179

174-
assert a_like.shape == a.shape, "full_like() produced an array with incorrect shape"
175-
if is_float_dtype(a_like.dtype) and isnan(asarray(fill_value)):
176-
assert all(isnan(a_like)), "full_like() array did not equal the fill value"
180+
assert x_like.shape == x.shape, "full_like() produced an array with incorrect shape"
181+
if is_float_dtype(x_like.dtype) and isnan(asarray(fill_value)):
182+
assert all(isnan(x_like)), "full_like() array did not equal the fill value"
177183
else:
178-
assert all(equal(a_like, asarray(fill_value, dtype=a_like.dtype))), "full_like() array did not equal the fill value"
184+
assert all(equal(x_like, asarray(fill_value, dtype=x_like.dtype))), "full_like() array did not equal the fill value"
179185

180186

181187
@given(scalars(shared_dtypes, finite=True),

0 commit comments

Comments
 (0)