Skip to content

Commit 37fe094

Browse files
xadupresdpython
andauthored
Implements ArrayAPI (#17)
* rename ArrayApi into BaseArrayApi * Implements ArrayAPI * documentation * ci * fix ci * xi * ci * ci * ci * ci * many changes to follow the Array API * more changes * fix unit test * fix ci * ci * api * ci * improvments * refactorign * fix asarray * new udpates * fix two bugs * Add one unit test for empty input * fix all when shape is empty and has one dimension * fix missing return * remove the full tests --------- Co-authored-by: xavier dupré <[email protected]>
1 parent 19f6f8b commit 37fe094

28 files changed

+767
-148
lines changed

CHANGELOGS.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,5 @@ Change Logs
44
0.2.0
55
+++++
66

7-
* :pr:`3`: fixes Array API with onnxruntime
7+
* :pr:`17`: implements ArrayAPI
8+
* :pr:`3`: fixes Array API with onnxruntime and scikit-learn

_doc/api/array_api.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
onnx_array_api.array_api
2+
========================
3+
4+
.. toctree::
5+
6+
array_api_onnx_numpy
7+
array_api_onnx_ort

_doc/api/array_api_numpy.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
onnx_array_api.array_api.onnx_numpy
2+
=============================================
3+
4+
.. automodule:: onnx_array_api.array_api.onnx_numpy
5+
:members:

_doc/api/array_api_ort.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
onnx_array_api.array_api.onnx_ort
2+
=================================
3+
4+
.. automodule:: onnx_array_api.array_api.onnx_ort
5+
:members:

_doc/api/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ API
66
.. toctree::
77
:maxdepth: 1
88

9+
array_api
910
npx_functions
1011
npx_var
1112
npx_jit

_doc/api/npx_annot.rst

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,54 @@
1+
=============
12
npx.npx_types
23
=============
34

5+
DType
6+
=====
7+
8+
.. autoclass:: onnx_array_api.npx.npx_types.DType
9+
:members:
10+
411
Annotations
5-
+++++++++++
12+
===========
13+
14+
ElemType
15+
++++++++
616

717
.. autoclass:: onnx_array_api.npx.npx_types.ElemType
818
:members:
919

20+
ParType
21+
+++++++
22+
1023
.. autoclass:: onnx_array_api.npx.npx_types.ParType
1124
:members:
1225

26+
OptParType
27+
++++++++++
28+
1329
.. autoclass:: onnx_array_api.npx.npx_types.OptParType
1430
:members:
1531

32+
TensorType
33+
++++++++++
34+
1635
.. autoclass:: onnx_array_api.npx.npx_types.TensorType
1736
:members:
1837

38+
SequenceType
39+
++++++++++++
40+
1941
.. autoclass:: onnx_array_api.npx.npx_types.SequenceType
2042
:members:
2143

44+
TupleType
45+
+++++++++
46+
2247
.. autoclass:: onnx_array_api.npx.npx_types.TupleType
2348
:members:
2449

2550
Shortcuts
26-
+++++++++
51+
=========
2752

2853
.. autoclass:: onnx_array_api.npx.npx_types.Bool
2954

_unittests/test_array_api.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy
2+
pytest ../array-api-tests/array_api_tests/test_creation_functions.py::test_zeros
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import unittest
2+
import numpy as np
3+
from onnx_array_api.ext_test_case import ExtTestCase
4+
from onnx_array_api.array_api import onnx_numpy as xp
5+
from onnx_array_api.npx.npx_numpy_tensors import EagerNumpyTensor
6+
7+
8+
class TestOnnxNumpy(ExtTestCase):
9+
def test_abs(self):
10+
c = EagerNumpyTensor(np.array([4, 5], dtype=np.int64))
11+
mat = xp.zeros(c, dtype=xp.int64)
12+
matnp = mat.numpy()
13+
self.assertEqual(matnp.shape, (4, 5))
14+
self.assertNotEmpty(matnp[0, 0])
15+
a = xp.absolute(mat)
16+
self.assertEqualArray(np.absolute(mat.numpy()), a.numpy())
17+
18+
19+
if __name__ == "__main__":
20+
unittest.main(verbosity=2)

_unittests/ut_npx/test_npx.py

Lines changed: 88 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
npxapi_inline,
3030
)
3131
from onnx_array_api.npx.npx_functions import absolute as absolute_inline
32+
from onnx_array_api.npx.npx_functions import all as all_inline
3233
from onnx_array_api.npx.npx_functions import arange as arange_inline
3334
from onnx_array_api.npx.npx_functions import arccos as arccos_inline
3435
from onnx_array_api.npx.npx_functions import arccosh as arccosh_inline
@@ -50,6 +51,7 @@
5051
from onnx_array_api.npx.npx_functions import det as det_inline
5152
from onnx_array_api.npx.npx_functions import dot as dot_inline
5253
from onnx_array_api.npx.npx_functions import einsum as einsum_inline
54+
from onnx_array_api.npx.npx_functions import equal as equal_inline
5355
from onnx_array_api.npx.npx_functions import erf as erf_inline
5456
from onnx_array_api.npx.npx_functions import exp as exp_inline
5557
from onnx_array_api.npx.npx_functions import expand_dims as expand_dims_inline
@@ -95,6 +97,7 @@
9597
from onnx_array_api.npx.npx_numpy_tensors import EagerNumpyTensor
9698
from onnx_array_api.npx.npx_types import (
9799
Bool,
100+
DType,
98101
Float32,
99102
Float64,
100103
Int64,
@@ -127,18 +130,25 @@ def test_tensor(self):
127130
self.assertEqual(dt.dtypes[0].dtype, ElemType.float32)
128131
self.assertEmpty(dt.shape)
129132
self.assertEqual(dt.type_name(), "TensorType['float32']")
133+
130134
dt = TensorType["float32"]
131135
self.assertEqual(len(dt.dtypes), 1)
132136
self.assertEqual(dt.dtypes[0].dtype, ElemType.float32)
133137
self.assertEqual(dt.type_name(), "TensorType['float32']")
138+
134139
dt = TensorType[np.float32]
135140
self.assertEqual(len(dt.dtypes), 1)
136141
self.assertEqual(dt.dtypes[0].dtype, ElemType.float32)
137142
self.assertEqual(dt.type_name(), "TensorType['float32']")
138143
self.assertEmpty(dt.shape)
139144

145+
dt = TensorType[np.str_]
146+
self.assertEqual(len(dt.dtypes), 1)
147+
self.assertEqual(dt.dtypes[0].dtype, ElemType.str_)
148+
self.assertEqual(dt.type_name(), "TensorType[strings]")
149+
self.assertEmpty(dt.shape)
150+
140151
self.assertRaise(lambda: TensorType[None], TypeError)
141-
self.assertRaise(lambda: TensorType[np.str_], TypeError)
142152
self.assertRaise(lambda: TensorType[{np.float32, np.str_}], TypeError)
143153

144154
def test_superset(self):
@@ -1155,6 +1165,16 @@ def test_astype(self):
11551165
got = ref.run(None, {"A": x})
11561166
self.assertEqualArray(z, got[0])
11571167

1168+
def test_astype_dtype(self):
1169+
f = absolute_inline(copy_inline(Input("A")).astype(DType(7)))
1170+
self.assertIsInstance(f, Var)
1171+
onx = f.to_onnx(constraints={"A": Float64[None]})
1172+
x = np.array([[-5.4, 6.6]], dtype=np.float64)
1173+
z = np.abs(x.astype(np.int64))
1174+
ref = ReferenceEvaluator(onx)
1175+
got = ref.run(None, {"A": x})
1176+
self.assertEqualArray(z, got[0])
1177+
11581178
def test_astype_int(self):
11591179
f = absolute_inline(copy_inline(Input("A")).astype(1))
11601180
self.assertIsInstance(f, Var)
@@ -1413,6 +1433,9 @@ def test_einsum(self):
14131433
lambda x, y: np.einsum(equation, x, y),
14141434
)
14151435

1436+
def test_equal(self):
1437+
self.common_test_inline_bin(equal_inline, np.equal)
1438+
14161439
@unittest.skipIf(scipy is None, reason="scipy is not installed.")
14171440
def test_erf(self):
14181441
self.common_test_inline(erf_inline, scipy.special.erf)
@@ -1460,7 +1483,17 @@ def test_hstack(self):
14601483
def test_identity(self):
14611484
f = identity_inline(2, dtype=np.float64)
14621485
onx = f.to_onnx(constraints={(0, False): Float64[None]})
1463-
z = np.identity(2)
1486+
self.assertIn('name: "dtype"', str(onx))
1487+
z = np.identity(2).astype(np.float64)
1488+
ref = ReferenceEvaluator(onx)
1489+
got = ref.run(None, {})
1490+
self.assertEqualArray(z, got[0])
1491+
1492+
def test_identity_uint8(self):
1493+
f = identity_inline(2, dtype=np.uint8)
1494+
onx = f.to_onnx(constraints={(0, False): Float64[None]})
1495+
self.assertIn('name: "dtype"', str(onx))
1496+
z = np.identity(2).astype(np.uint8)
14641497
ref = ReferenceEvaluator(onx)
14651498
got = ref.run(None, {})
14661499
self.assertEqualArray(z, got[0])
@@ -2318,7 +2351,7 @@ def compute_labels(X, centers):
23182351
self.assertEqual(f.n_versions, 1)
23192352
self.assertEqual(len(f.available_versions), 1)
23202353
self.assertEqual(f.available_versions, [((np.float64, 2), (np.float64, 2))])
2321-
key = ((np.dtype("float64"), 2), (np.dtype("float64"), 2))
2354+
key = ((DType(TensorProto.DOUBLE), 2), (DType(TensorProto.DOUBLE), 2))
23222355
onx = f.get_onnx(key)
23232356
self.assertIsInstance(onx, ModelProto)
23242357
self.assertRaise(lambda: f.get_onnx(2), ValueError)
@@ -2379,7 +2412,12 @@ def compute_labels(X, centers, use_sqrt=False):
23792412
self.assertEqualArray(got[1], dist)
23802413
self.assertEqual(f.n_versions, 1)
23812414
self.assertEqual(len(f.available_versions), 1)
2382-
key = ((np.dtype("float64"), 2), (np.dtype("float64"), 2), "use_sqrt", True)
2415+
key = (
2416+
(DType(TensorProto.DOUBLE), 2),
2417+
(DType(TensorProto.DOUBLE), 2),
2418+
"use_sqrt",
2419+
True,
2420+
)
23832421
self.assertEqual(f.available_versions, [key])
23842422
onx = f.get_onnx(key)
23852423
self.assertIsInstance(onx, ModelProto)
@@ -2452,7 +2490,52 @@ def test_take(self):
24522490
got = ref.run(None, {"A": data, "B": indices})
24532491
self.assertEqualArray(y, got[0])
24542492

2493+
def test_numpy_all(self):
2494+
data = np.array([[1, 0], [1, 1]]).astype(np.bool_)
2495+
y = np.all(data, axis=1)
2496+
2497+
f = all_inline(Input("A"), axis=1)
2498+
self.assertIsInstance(f, Var)
2499+
onx = f.to_onnx(constraints={"A": Bool[None]})
2500+
ref = ReferenceEvaluator(onx)
2501+
got = ref.run(None, {"A": data})
2502+
self.assertEqualArray(y, got[0])
2503+
2504+
def test_numpy_all_empty(self):
2505+
data = np.zeros((0,), dtype=np.bool_)
2506+
y = np.all(data)
2507+
2508+
f = all_inline(Input("A"))
2509+
self.assertIsInstance(f, Var)
2510+
onx = f.to_onnx(constraints={"A": Bool[None]})
2511+
ref = ReferenceEvaluator(onx)
2512+
got = ref.run(None, {"A": data})
2513+
self.assertEqualArray(y, got[0])
2514+
2515+
@unittest.skipIf(True, reason="ReduceMin does not support shape[axis] == 0")
2516+
def test_numpy_all_empty_axis_0(self):
2517+
data = np.zeros((0, 1), dtype=np.bool_)
2518+
y = np.all(data, axis=0)
2519+
2520+
f = all_inline(Input("A"), axis=0)
2521+
self.assertIsInstance(f, Var)
2522+
onx = f.to_onnx(constraints={"A": Bool[None]})
2523+
ref = ReferenceEvaluator(onx)
2524+
got = ref.run(None, {"A": data})
2525+
self.assertEqualArray(y, got[0])
2526+
2527+
def test_numpy_all_empty_axis_1(self):
2528+
data = np.zeros((0, 1), dtype=np.bool_)
2529+
y = np.all(data, axis=1)
2530+
2531+
f = all_inline(Input("A"), axis=1)
2532+
self.assertIsInstance(f, Var)
2533+
onx = f.to_onnx(constraints={"A": Bool[None]})
2534+
ref = ReferenceEvaluator(onx)
2535+
got = ref.run(None, {"A": data})
2536+
self.assertEqualArray(y, got[0])
2537+
24552538

24562539
if __name__ == "__main__":
2457-
TestNpx().test_take()
2540+
# TestNpx().test_numpy_all_empty_axis_0()
24582541
unittest.main(verbosity=2)

_unittests/ut_npx/test_sklearn_array_api.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from onnx.defs import onnx_opset_version
55
from sklearn import config_context, __version__ as sklearn_version
66
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
7-
from onnx_array_api.ext_test_case import ExtTestCase
7+
from onnx_array_api.ext_test_case import ExtTestCase, ignore_warnings
88
from onnx_array_api.npx.npx_numpy_tensors import EagerNumpyTensor
99

1010

@@ -16,6 +16,7 @@ class TestSklearnArrayAPI(ExtTestCase):
1616
Version(sklearn_version) <= Version("1.2.2"),
1717
reason="reshape ArrayAPI not followed",
1818
)
19+
@ignore_warnings(DeprecationWarning)
1920
def test_sklearn_array_api_linear_discriminant(self):
2021
X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]])
2122
y = np.array([1, 1, 1, 2, 2, 2])
@@ -26,6 +27,8 @@ def test_sklearn_array_api_linear_discriminant(self):
2627
new_x = EagerNumpyTensor(X)
2728
self.assertStartsWith("EagerNumpyTensor(array([[", repr(new_x))
2829
with config_context(array_api_dispatch=True):
30+
# It fails if scikit-learn <= 1.2.2 because the ArrayAPI
31+
# is not strictly applied.
2932
got = ana.predict(new_x)
3033
self.assertEqualArray(expected, got.numpy())
3134

0 commit comments

Comments
 (0)