Skip to content

Commit 3f6e330

Browse files
committed
Fixed test_full_like
1 parent 8b11476 commit 3f6e330

File tree

3 files changed

+12
-6
lines changed

3 files changed

+12
-6
lines changed

array_api_tests/hypothesis_helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22
from operator import mul
33
from math import sqrt
44

5+
from hypothesis import assume
56
from hypothesis.strategies import (lists, integers, sampled_from,
67
shared, floats, just, composite, one_of,
78
none, booleans)
89
from hypothesis.extra.array_api import make_strategies_namespace
9-
from hypothesis import assume
1010

1111
from .pytest_helpers import nargs
1212
from .array_helpers import (dtype_ranges, integer_dtype_objects,

array_api_tests/meta_tests/test_hypothesis_helpers.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from math import prod
22

33
import pytest
4-
from hypothesis import given, strategies as st, assume
4+
from hypothesis import given, strategies as st
55

66
from .. import _array_module as xp
77
from .._array_module import _UndefinedStub
@@ -70,4 +70,3 @@ def run(kw):
7070
c_results = [kw for kw in results if "c" in kw]
7171
assert len(c_results) > 0
7272
assert all(isinstance(kw["c"], str) for kw in c_results)
73-

array_api_tests/test_creation_functions.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
scalars, xps, kwargs)
99

1010
from hypothesis import assume, given
11-
from hypothesis.strategies import integers, floats, one_of, none, booleans, just, shared
11+
from hypothesis.strategies import integers, floats, one_of, none, booleans, just, shared, composite
1212

1313

1414
optional_dtypes = none() | shared_dtypes
@@ -148,10 +148,17 @@ def test_full(shape, fill_value, dtype):
148148
assert all(equal(a, asarray(fill_value, **kwargs))), "full() array did not equal the fill value"
149149

150150

151+
@composite
152+
def fill_values(draw):
153+
kw = draw(shared(kwargs(dtype=none() | xps.scalar_dtypes()), key="full_like_kw"))
154+
dtype = kw.get("dtype", None) or draw(shared_dtypes)
155+
return draw(xps.from_dtype(dtype))
156+
157+
151158
@given(
152159
x=xps.arrays(dtype=shared_dtypes, shape=shapes),
153-
fill_value=shared_dtypes.flatmap(xps.from_dtype),
154-
kw=kwargs(dtype=none() | shared_dtypes),
160+
fill_value=fill_values(),
161+
kw=shared(kwargs(dtype=none() | xps.scalar_dtypes()), key="full_like_kw"),
155162
)
156163
def test_full_like(x, fill_value, kw):
157164
out = full_like(x, fill_value, **kw)

0 commit comments

Comments
 (0)