|
29 | 29 | npxapi_inline,
|
30 | 30 | )
|
31 | 31 | from onnx_array_api.npx.npx_functions import absolute as absolute_inline
|
| 32 | +from onnx_array_api.npx.npx_functions import all as all_inline |
32 | 33 | from onnx_array_api.npx.npx_functions import arange as arange_inline
|
33 | 34 | from onnx_array_api.npx.npx_functions import arccos as arccos_inline
|
34 | 35 | from onnx_array_api.npx.npx_functions import arccosh as arccosh_inline
|
|
50 | 51 | from onnx_array_api.npx.npx_functions import det as det_inline
|
51 | 52 | from onnx_array_api.npx.npx_functions import dot as dot_inline
|
52 | 53 | from onnx_array_api.npx.npx_functions import einsum as einsum_inline
|
| 54 | +from onnx_array_api.npx.npx_functions import equal as equal_inline |
53 | 55 | from onnx_array_api.npx.npx_functions import erf as erf_inline
|
54 | 56 | from onnx_array_api.npx.npx_functions import exp as exp_inline
|
55 | 57 | from onnx_array_api.npx.npx_functions import expand_dims as expand_dims_inline
|
|
95 | 97 | from onnx_array_api.npx.npx_numpy_tensors import EagerNumpyTensor
|
96 | 98 | from onnx_array_api.npx.npx_types import (
|
97 | 99 | Bool,
|
| 100 | + DType, |
98 | 101 | Float32,
|
99 | 102 | Float64,
|
100 | 103 | Int64,
|
@@ -127,18 +130,25 @@ def test_tensor(self):
|
127 | 130 | self.assertEqual(dt.dtypes[0].dtype, ElemType.float32)
|
128 | 131 | self.assertEmpty(dt.shape)
|
129 | 132 | self.assertEqual(dt.type_name(), "TensorType['float32']")
|
| 133 | + |
130 | 134 | dt = TensorType["float32"]
|
131 | 135 | self.assertEqual(len(dt.dtypes), 1)
|
132 | 136 | self.assertEqual(dt.dtypes[0].dtype, ElemType.float32)
|
133 | 137 | self.assertEqual(dt.type_name(), "TensorType['float32']")
|
| 138 | + |
134 | 139 | dt = TensorType[np.float32]
|
135 | 140 | self.assertEqual(len(dt.dtypes), 1)
|
136 | 141 | self.assertEqual(dt.dtypes[0].dtype, ElemType.float32)
|
137 | 142 | self.assertEqual(dt.type_name(), "TensorType['float32']")
|
138 | 143 | self.assertEmpty(dt.shape)
|
139 | 144 |
|
| 145 | + dt = TensorType[np.str_] |
| 146 | + self.assertEqual(len(dt.dtypes), 1) |
| 147 | + self.assertEqual(dt.dtypes[0].dtype, ElemType.str_) |
| 148 | + self.assertEqual(dt.type_name(), "TensorType[strings]") |
| 149 | + self.assertEmpty(dt.shape) |
| 150 | + |
140 | 151 | self.assertRaise(lambda: TensorType[None], TypeError)
|
141 |
| - self.assertRaise(lambda: TensorType[np.str_], TypeError) |
142 | 152 | self.assertRaise(lambda: TensorType[{np.float32, np.str_}], TypeError)
|
143 | 153 |
|
144 | 154 | def test_superset(self):
|
@@ -1155,6 +1165,16 @@ def test_astype(self):
|
1155 | 1165 | got = ref.run(None, {"A": x})
|
1156 | 1166 | self.assertEqualArray(z, got[0])
|
1157 | 1167 |
|
| 1168 | + def test_astype_dtype(self): |
| 1169 | + f = absolute_inline(copy_inline(Input("A")).astype(DType(7))) |
| 1170 | + self.assertIsInstance(f, Var) |
| 1171 | + onx = f.to_onnx(constraints={"A": Float64[None]}) |
| 1172 | + x = np.array([[-5.4, 6.6]], dtype=np.float64) |
| 1173 | + z = np.abs(x.astype(np.int64)) |
| 1174 | + ref = ReferenceEvaluator(onx) |
| 1175 | + got = ref.run(None, {"A": x}) |
| 1176 | + self.assertEqualArray(z, got[0]) |
| 1177 | + |
1158 | 1178 | def test_astype_int(self):
|
1159 | 1179 | f = absolute_inline(copy_inline(Input("A")).astype(1))
|
1160 | 1180 | self.assertIsInstance(f, Var)
|
@@ -1413,6 +1433,9 @@ def test_einsum(self):
|
1413 | 1433 | lambda x, y: np.einsum(equation, x, y),
|
1414 | 1434 | )
|
1415 | 1435 |
|
| 1436 | + def test_equal(self): |
| 1437 | + self.common_test_inline_bin(equal_inline, np.equal) |
| 1438 | + |
1416 | 1439 | @unittest.skipIf(scipy is None, reason="scipy is not installed.")
|
1417 | 1440 | def test_erf(self):
|
1418 | 1441 | self.common_test_inline(erf_inline, scipy.special.erf)
|
@@ -1460,7 +1483,17 @@ def test_hstack(self):
|
1460 | 1483 | def test_identity(self):
|
1461 | 1484 | f = identity_inline(2, dtype=np.float64)
|
1462 | 1485 | onx = f.to_onnx(constraints={(0, False): Float64[None]})
|
1463 |
| - z = np.identity(2) |
| 1486 | + self.assertIn('name: "dtype"', str(onx)) |
| 1487 | + z = np.identity(2).astype(np.float64) |
| 1488 | + ref = ReferenceEvaluator(onx) |
| 1489 | + got = ref.run(None, {}) |
| 1490 | + self.assertEqualArray(z, got[0]) |
| 1491 | + |
| 1492 | + def test_identity_uint8(self): |
| 1493 | + f = identity_inline(2, dtype=np.uint8) |
| 1494 | + onx = f.to_onnx(constraints={(0, False): Float64[None]}) |
| 1495 | + self.assertIn('name: "dtype"', str(onx)) |
| 1496 | + z = np.identity(2).astype(np.uint8) |
1464 | 1497 | ref = ReferenceEvaluator(onx)
|
1465 | 1498 | got = ref.run(None, {})
|
1466 | 1499 | self.assertEqualArray(z, got[0])
|
@@ -2318,7 +2351,7 @@ def compute_labels(X, centers):
|
2318 | 2351 | self.assertEqual(f.n_versions, 1)
|
2319 | 2352 | self.assertEqual(len(f.available_versions), 1)
|
2320 | 2353 | self.assertEqual(f.available_versions, [((np.float64, 2), (np.float64, 2))])
|
2321 |
| - key = ((np.dtype("float64"), 2), (np.dtype("float64"), 2)) |
| 2354 | + key = ((DType(TensorProto.DOUBLE), 2), (DType(TensorProto.DOUBLE), 2)) |
2322 | 2355 | onx = f.get_onnx(key)
|
2323 | 2356 | self.assertIsInstance(onx, ModelProto)
|
2324 | 2357 | self.assertRaise(lambda: f.get_onnx(2), ValueError)
|
@@ -2379,7 +2412,12 @@ def compute_labels(X, centers, use_sqrt=False):
|
2379 | 2412 | self.assertEqualArray(got[1], dist)
|
2380 | 2413 | self.assertEqual(f.n_versions, 1)
|
2381 | 2414 | self.assertEqual(len(f.available_versions), 1)
|
2382 |
| - key = ((np.dtype("float64"), 2), (np.dtype("float64"), 2), "use_sqrt", True) |
| 2415 | + key = ( |
| 2416 | + (DType(TensorProto.DOUBLE), 2), |
| 2417 | + (DType(TensorProto.DOUBLE), 2), |
| 2418 | + "use_sqrt", |
| 2419 | + True, |
| 2420 | + ) |
2383 | 2421 | self.assertEqual(f.available_versions, [key])
|
2384 | 2422 | onx = f.get_onnx(key)
|
2385 | 2423 | self.assertIsInstance(onx, ModelProto)
|
@@ -2452,7 +2490,52 @@ def test_take(self):
|
2452 | 2490 | got = ref.run(None, {"A": data, "B": indices})
|
2453 | 2491 | self.assertEqualArray(y, got[0])
|
2454 | 2492 |
|
| 2493 | + def test_numpy_all(self): |
| 2494 | + data = np.array([[1, 0], [1, 1]]).astype(np.bool_) |
| 2495 | + y = np.all(data, axis=1) |
| 2496 | + |
| 2497 | + f = all_inline(Input("A"), axis=1) |
| 2498 | + self.assertIsInstance(f, Var) |
| 2499 | + onx = f.to_onnx(constraints={"A": Bool[None]}) |
| 2500 | + ref = ReferenceEvaluator(onx) |
| 2501 | + got = ref.run(None, {"A": data}) |
| 2502 | + self.assertEqualArray(y, got[0]) |
| 2503 | + |
| 2504 | + def test_numpy_all_empty(self): |
| 2505 | + data = np.zeros((0,), dtype=np.bool_) |
| 2506 | + y = np.all(data) |
| 2507 | + |
| 2508 | + f = all_inline(Input("A")) |
| 2509 | + self.assertIsInstance(f, Var) |
| 2510 | + onx = f.to_onnx(constraints={"A": Bool[None]}) |
| 2511 | + ref = ReferenceEvaluator(onx) |
| 2512 | + got = ref.run(None, {"A": data}) |
| 2513 | + self.assertEqualArray(y, got[0]) |
| 2514 | + |
| 2515 | + @unittest.skipIf(True, reason="ReduceMin does not support shape[axis] == 0") |
| 2516 | + def test_numpy_all_empty_axis_0(self): |
| 2517 | + data = np.zeros((0, 1), dtype=np.bool_) |
| 2518 | + y = np.all(data, axis=0) |
| 2519 | + |
| 2520 | + f = all_inline(Input("A"), axis=0) |
| 2521 | + self.assertIsInstance(f, Var) |
| 2522 | + onx = f.to_onnx(constraints={"A": Bool[None]}) |
| 2523 | + ref = ReferenceEvaluator(onx) |
| 2524 | + got = ref.run(None, {"A": data}) |
| 2525 | + self.assertEqualArray(y, got[0]) |
| 2526 | + |
| 2527 | + def test_numpy_all_empty_axis_1(self): |
| 2528 | + data = np.zeros((0, 1), dtype=np.bool_) |
| 2529 | + y = np.all(data, axis=1) |
| 2530 | + |
| 2531 | + f = all_inline(Input("A"), axis=1) |
| 2532 | + self.assertIsInstance(f, Var) |
| 2533 | + onx = f.to_onnx(constraints={"A": Bool[None]}) |
| 2534 | + ref = ReferenceEvaluator(onx) |
| 2535 | + got = ref.run(None, {"A": data}) |
| 2536 | + self.assertEqualArray(y, got[0]) |
| 2537 | + |
2455 | 2538 |
|
2456 | 2539 | if __name__ == "__main__":
|
2457 |
| - TestNpx().test_take() |
| 2540 | + # TestNpx().test_numpy_all_empty_axis_0() |
2458 | 2541 | unittest.main(verbosity=2)
|
0 commit comments