Skip to content

Commit a5df654

Browse files
committed
Merge branch 'main' into test_all
2 parents 1f135c9 + f7bd970 commit a5df654

File tree

4 files changed

+24
-14
lines changed

4 files changed

+24
-14
lines changed

.github/workflows/docs-deploy.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ jobs:
1313
steps:
1414
- uses: actions/checkout@v4
1515
- name: Download Artifact
16-
uses: dawidd6/action-download-artifact@v9
16+
uses: dawidd6/action-download-artifact@v10
1717
with:
1818
workflow: docs-build.yml
1919
name: docs-build

.github/workflows/tests.yml

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,20 +32,20 @@ jobs:
3232
python -m pip install --upgrade pip
3333
python -m pip install pytest
3434
35+
# Don't `pip install .[dev]` as it would pull in the whole torch cuda stack
36+
python -m pip install array-api-strict
37+
python -m pip install torch --index-url https://download.pytorch.org/whl/cpu
38+
3539
if [ "${{ matrix.numpy-version }}" == "dev" ]; then
3640
python -m pip install numpy --pre --extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple
41+
python -m pip install dask[array] jax[cpu] sparse ndonnx
3742
elif [ "${{ matrix.numpy-version }}" == "1.22" ]; then
3843
python -m pip install 'numpy==1.22.*'
3944
elif [ "${{ matrix.numpy-version }}" == "1.26" ]; then
4045
python -m pip install 'numpy==1.26.*'
4146
else
42-
# Don't `pip install .[dev]` as it would pull in the whole torch cuda stack
43-
python -m pip install array-api-strict dask[array] jax[cpu] numpy sparse
44-
python -m pip install torch --index-url https://download.pytorch.org/whl/cpu
45-
if [ "${{ matrix.python-version }}" != "3.13" ]; then
46-
# onnx wheels are not available on Python 3.13 at the moment of writing
47-
python -m pip install ndonnx
48-
fi
47+
python -m pip install numpy
48+
python -m pip install dask[array] jax[cpu] sparse ndonnx
4949
fi
5050
5151
- name: Dump pip environment

tests/test_array_namespace.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@ def test_array_namespace(library, api_version, use_compat):
2323
if library == "ndonnx" and api_version in ("2021.12", "2022.12"):
2424
pytest.skip("Unsupported API version")
2525

26-
namespace = array_namespace(array, api_version=api_version, use_compat=use_compat)
26+
with warnings.catch_warnings():
27+
warnings.simplefilter('ignore', UserWarning)
28+
namespace = array_namespace(array, api_version=api_version, use_compat=use_compat)
2729

2830
if use_compat is False or use_compat is None and library not in wrapped_libraries:
2931
if library == "jax.numpy" and use_compat is None:
@@ -45,10 +47,13 @@ def test_array_namespace(library, api_version, use_compat):
4547

4648
if library == "numpy":
4749
# check that the same namespace is returned for NumPy scalars
48-
scalar_namespace = array_namespace(
49-
xp.float64(0.0), api_version=api_version, use_compat=use_compat
50-
)
51-
assert scalar_namespace == namespace
50+
with warnings.catch_warnings():
51+
warnings.simplefilter('ignore', UserWarning)
52+
53+
scalar_namespace = array_namespace(
54+
xp.float64(0.0), api_version=api_version, use_compat=use_compat
55+
)
56+
assert scalar_namespace == namespace
5257

5358
# Check that array_namespace works even if jax.experimental.array_api
5459
# hasn't been imported yet (it monkeypatches __array_namespace__
@@ -97,7 +102,9 @@ def test_api_version_torch():
97102
torch = import_("torch")
98103
x = torch.asarray([1, 2])
99104
torch_ = import_("torch", wrapper=True)
100-
assert array_namespace(x, api_version="2023.12") == torch_
105+
with warnings.catch_warnings():
106+
warnings.simplefilter('ignore', UserWarning)
107+
assert array_namespace(x, api_version="2023.12") == torch_
101108
assert array_namespace(x, api_version=None) == torch_
102109
assert array_namespace(x) == torch_
103110
# Should issue a warning

tests/test_common.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,9 @@ def test_device_to_device(library, request):
195195
xfail(request, reason="Stub raises ValueError")
196196
if library == "sparse":
197197
xfail(request, reason="No __array_namespace_info__()")
198+
if library == "array_api_strict":
199+
if np.__version__ < "2":
200+
xfail(request, reason="no copy argument of np.asarray")
198201

199202
xp = import_(library, wrapper=True)
200203
devices = xp.__array_namespace_info__().devices()

0 commit comments

Comments
 (0)