Skip to content

Commit 2b559e6

Browse files
authored
TYP: Type annotations, part 4 (#313)
* Type annotations, part 4 * Fix CopyMode * revert * Revert `_all_ignore` * code review * code review * JustInt mypy ignores * lint * fix merge * lint * Reverts and tweaks * Fix test_all * Revert batmobile
1 parent 6ae28ee commit 2b559e6

File tree

21 files changed

+246
-261
lines changed

21 files changed

+246
-261
lines changed

array_api_compat/_internal.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@ def wrapped_f(*args: object, **kwargs: object) -> object:
4646
specification for more details.
4747
4848
"""
49-
wrapped_f.__signature__ = new_sig # pyright: ignore[reportAttributeAccessIssue]
50-
return wrapped_f # pyright: ignore[reportReturnType]
49+
wrapped_f.__signature__ = new_sig # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
50+
return wrapped_f # type: ignore[return-value] # pyright: ignore[reportReturnType]
5151

5252
return inner
5353

array_api_compat/common/_aliases.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,12 @@
55
from __future__ import annotations
66

77
import inspect
8-
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Sequence, cast
8+
from collections.abc import Sequence
9+
from typing import TYPE_CHECKING, Any, NamedTuple, cast
910

1011
from ._helpers import _check_device, array_namespace
1112
from ._helpers import device as _get_device
12-
from ._helpers import is_cupy_namespace as _is_cupy_namespace
13+
from ._helpers import is_cupy_namespace
1314
from ._typing import Array, Device, DType, Namespace
1415

1516
if TYPE_CHECKING:
@@ -381,8 +382,8 @@ def clip(
381382
# TODO: np.clip has other ufunc kwargs
382383
out: Array | None = None,
383384
) -> Array:
384-
def _isscalar(a: object) -> TypeIs[int | float | None]:
385-
return isinstance(a, (int, float, type(None)))
385+
def _isscalar(a: object) -> TypeIs[float | None]:
386+
return isinstance(a, int | float) or a is None
386387

387388
min_shape = () if _isscalar(min) else min.shape
388389
max_shape = () if _isscalar(max) else max.shape
@@ -450,7 +451,7 @@ def reshape(
450451
shape: tuple[int, ...],
451452
xp: Namespace,
452453
*,
453-
copy: Optional[bool] = None,
454+
copy: bool | None = None,
454455
**kwargs: object,
455456
) -> Array:
456457
if copy is True:
@@ -657,7 +658,7 @@ def sign(x: Array, /, xp: Namespace, **kwargs: object) -> Array:
657658
out = xp.sign(x, **kwargs)
658659
# CuPy sign() does not propagate nans. See
659660
# https://github.com/data-apis/array-api-compat/issues/136
660-
if _is_cupy_namespace(xp) and isdtype(x.dtype, "real floating", xp=xp):
661+
if is_cupy_namespace(xp) and isdtype(x.dtype, "real floating", xp=xp):
661662
out[xp.isnan(x)] = xp.nan
662663
return out[()]
663664

array_api_compat/common/_helpers.py

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -23,40 +23,36 @@
2323
SupportsIndex,
2424
TypeAlias,
2525
TypeGuard,
26-
TypeVar,
2726
cast,
2827
overload,
2928
)
3029

3130
from ._typing import Array, Device, HasShape, Namespace, SupportsArrayNamespace
3231

3332
if TYPE_CHECKING:
34-
33+
import cupy as cp
3534
import dask.array as da
3635
import jax
3736
import ndonnx as ndx
3837
import numpy as np
3938
import numpy.typing as npt
40-
import sparse # pyright: ignore[reportMissingTypeStubs]
39+
import sparse
4140
import torch
4241

4342
# TODO: import from typing (requires Python >=3.13)
44-
from typing_extensions import TypeIs, TypeVar
45-
46-
_SizeT = TypeVar("_SizeT", bound = int | None)
43+
from typing_extensions import TypeIs
4744

4845
_ZeroGradientArray: TypeAlias = npt.NDArray[np.void]
49-
_CupyArray: TypeAlias = Any # cupy has no py.typed
5046

5147
_ArrayApiObj: TypeAlias = (
5248
npt.NDArray[Any]
49+
| cp.ndarray
5350
| da.Array
5451
| jax.Array
5552
| ndx.Array
5653
| sparse.SparseArray
5754
| torch.Tensor
5855
| SupportsArrayNamespace[Any]
59-
| _CupyArray
6056
)
6157

6258
_API_VERSIONS_OLD: Final = frozenset({"2021.12", "2022.12", "2023.12"})
@@ -96,7 +92,7 @@ def _is_jax_zero_gradient_array(x: object) -> TypeGuard[_ZeroGradientArray]:
9692
return dtype == jax.float0
9793

9894

99-
def is_numpy_array(x: object) -> TypeGuard[npt.NDArray[Any]]:
95+
def is_numpy_array(x: object) -> TypeIs[npt.NDArray[Any]]:
10096
"""
10197
Return True if `x` is a NumPy array.
10298
@@ -267,7 +263,7 @@ def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]:
267263
return _issubclass_fast(cls, "sparse", "SparseArray")
268264

269265

270-
def is_array_api_obj(x: object) -> TypeIs[_ArrayApiObj]: # pyright: ignore[reportUnknownParameterType]
266+
def is_array_api_obj(x: object) -> TypeGuard[_ArrayApiObj]:
271267
"""
272268
Return True if `x` is an array API compatible array object.
273269
@@ -748,7 +744,7 @@ def device(x: _ArrayApiObj, /) -> Device:
748744
return "cpu"
749745
elif is_dask_array(x):
750746
# Peek at the metadata of the Dask array to determine type
751-
if is_numpy_array(x._meta): # pyright: ignore
747+
if is_numpy_array(x._meta):
752748
# Must be on CPU since backed by numpy
753749
return "cpu"
754750
return _DASK_DEVICE
@@ -777,7 +773,7 @@ def device(x: _ArrayApiObj, /) -> Device:
777773
return "cpu"
778774
# Return the device of the constituent array
779775
return device(inner) # pyright: ignore
780-
return x.device # pyright: ignore
776+
return x.device # type: ignore # pyright: ignore
781777

782778

783779
# Prevent shadowing, used below
@@ -786,11 +782,11 @@ def device(x: _ArrayApiObj, /) -> Device:
786782

787783
# Based on cupy.array_api.Array.to_device
788784
def _cupy_to_device(
789-
x: _CupyArray,
785+
x: cp.ndarray,
790786
device: Device,
791787
/,
792788
stream: int | Any | None = None,
793-
) -> _CupyArray:
789+
) -> cp.ndarray:
794790
import cupy as cp
795791

796792
if device == "cpu":
@@ -819,7 +815,7 @@ def _torch_to_device(
819815
x: torch.Tensor,
820816
device: torch.device | str | int,
821817
/,
822-
stream: None = None,
818+
stream: int | Any | None = None,
823819
) -> torch.Tensor:
824820
if stream is not None:
825821
raise NotImplementedError
@@ -885,7 +881,7 @@ def to_device(x: Array, device: Device, /, *, stream: int | Any | None = None) -
885881
# cupy does not yet have to_device
886882
return _cupy_to_device(x, device, stream=stream)
887883
elif is_torch_array(x):
888-
return _torch_to_device(x, device, stream=stream) # pyright: ignore[reportArgumentType]
884+
return _torch_to_device(x, device, stream=stream)
889885
elif is_dask_array(x):
890886
if stream is not None:
891887
raise ValueError("The stream argument to to_device() is not supported")
@@ -912,8 +908,6 @@ def to_device(x: Array, device: Device, /, *, stream: int | Any | None = None) -
912908
@overload
913909
def size(x: HasShape[Collection[SupportsIndex]]) -> int: ...
914910
@overload
915-
def size(x: HasShape[Collection[None]]) -> None: ...
916-
@overload
917911
def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None: ...
918912
def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None:
919913
"""
@@ -948,7 +942,7 @@ def _is_writeable_cls(cls: type) -> bool | None:
948942
return None
949943

950944

951-
def is_writeable_array(x: object) -> bool:
945+
def is_writeable_array(x: object) -> TypeGuard[_ArrayApiObj]:
952946
"""
953947
Return False if ``x.__setitem__`` is expected to raise; True otherwise.
954948
Return False if `x` is not an array API compatible object.
@@ -986,7 +980,7 @@ def _is_lazy_cls(cls: type) -> bool | None:
986980
return None
987981

988982

989-
def is_lazy_array(x: object) -> bool:
983+
def is_lazy_array(x: object) -> TypeGuard[_ArrayApiObj]:
990984
"""Return True if x is potentially a future or it may be otherwise impossible or
991985
expensive to eagerly read its contents, regardless of their size, e.g. by
992986
calling ``bool(x)`` or ``float(x)``.

array_api_compat/common/_linalg.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
if np.__version__[0] == "2":
99
from numpy.lib.array_utils import normalize_axis_tuple
1010
else:
11-
from numpy.core.numeric import normalize_axis_tuple
11+
from numpy.core.numeric import normalize_axis_tuple # type: ignore[no-redef]
1212

1313
from .._internal import get_xp
1414
from ._aliases import isdtype, matmul, matrix_transpose, tensordot, vecdot
@@ -187,14 +187,14 @@ def vector_norm(
187187
# We can't reuse xp.linalg.norm(keepdims) because of the reshape hacks
188188
# above to avoid matrix norm logic.
189189
shape = list(x.shape)
190-
_axis = cast(
190+
axes = cast(
191191
"tuple[int, ...]",
192192
normalize_axis_tuple( # pyright: ignore[reportCallIssue]
193193
range(x.ndim) if axis is None else axis,
194194
x.ndim,
195195
),
196196
)
197-
for i in _axis:
197+
for i in axes:
198198
shape[i] = 1
199199
res = xp.reshape(res, tuple(shape))
200200

array_api_compat/common/_typing.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,32 +34,29 @@
3434
# - docs: https://github.com/jorenham/optype/blob/master/README.md#just
3535
# - code: https://github.com/jorenham/optype/blob/master/optype/_core/_just.py
3636
@final
37-
class JustInt(Protocol):
38-
@property
37+
class JustInt(Protocol): # type: ignore[misc]
38+
@property # type: ignore[override]
3939
def __class__(self, /) -> type[int]: ...
4040
@__class__.setter
4141
def __class__(self, value: type[int], /) -> None: ... # pyright: ignore[reportIncompatibleMethodOverride]
4242

4343

4444
@final
45-
class JustFloat(Protocol):
46-
@property
45+
class JustFloat(Protocol): # type: ignore[misc]
46+
@property # type: ignore[override]
4747
def __class__(self, /) -> type[float]: ...
4848
@__class__.setter
4949
def __class__(self, value: type[float], /) -> None: ... # pyright: ignore[reportIncompatibleMethodOverride]
5050

5151

5252
@final
53-
class JustComplex(Protocol):
54-
@property
53+
class JustComplex(Protocol): # type: ignore[misc]
54+
@property # type: ignore[override]
5555
def __class__(self, /) -> type[complex]: ...
5656
@__class__.setter
5757
def __class__(self, value: type[complex], /) -> None: ... # pyright: ignore[reportIncompatibleMethodOverride]
5858

5959

60-
#
61-
62-
6360
class NestedSequence(Protocol[_T_co]):
6461
def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ...
6562
def __len__(self, /) -> int: ...

array_api_compat/cupy/_aliases.py

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import Optional
3+
from builtins import bool as py_bool
44

55
import cupy as cp
66

@@ -67,18 +67,13 @@
6767

6868
# asarray also adds the copy keyword, which is not present in numpy 1.0.
6969
def asarray(
70-
obj: (
71-
Array
72-
| bool | int | float | complex
73-
| NestedSequence[bool | int | float | complex]
74-
| SupportsBufferProtocol
75-
),
70+
obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol,
7671
/,
7772
*,
78-
dtype: Optional[DType] = None,
79-
device: Optional[Device] = None,
80-
copy: Optional[bool] = None,
81-
**kwargs,
73+
dtype: DType | None = None,
74+
device: Device | None = None,
75+
copy: py_bool | None = None,
76+
**kwargs: object,
8277
) -> Array:
8378
"""
8479
Array API compatibility wrapper for asarray().
@@ -101,8 +96,8 @@ def astype(
10196
dtype: DType,
10297
/,
10398
*,
104-
copy: bool = True,
105-
device: Optional[Device] = None,
99+
copy: py_bool = True,
100+
device: Device | None = None,
106101
) -> Array:
107102
if device is None:
108103
return x.astype(dtype=dtype, copy=copy)
@@ -113,8 +108,8 @@ def astype(
113108
# cupy.count_nonzero does not have keepdims
114109
def count_nonzero(
115110
x: Array,
116-
axis=None,
117-
keepdims=False
111+
axis: int | tuple[int, ...] | None = None,
112+
keepdims: py_bool = False,
118113
) -> Array:
119114
result = cp.count_nonzero(x, axis)
120115
if keepdims:
@@ -125,7 +120,7 @@ def count_nonzero(
125120

126121

127122
# take_along_axis: axis defaults to -1 but in cupy (and numpy) axis is a required arg
128-
def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1):
123+
def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array:
129124
return cp.take_along_axis(x, indices, axis=axis)
130125

131126

@@ -153,4 +148,6 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1):
153148
'bool', 'concat', 'count_nonzero', 'pow', 'sign',
154149
'take_along_axis']
155150

156-
_all_ignore = ['cp', 'get_xp']
151+
152+
def __dir__() -> list[str]:
153+
return __all__

array_api_compat/cupy/fft.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1-
from cupy.fft import * # noqa: F403
1+
from cupy.fft import * # noqa: F403
2+
23
# cupy.fft doesn't have __all__. If it is added, replace this with
34
#
45
# from cupy.fft import __all__ as linalg_all
5-
_n = {}
6-
exec('from cupy.fft import *', _n)
7-
del _n['__builtins__']
6+
_n: dict[str, object] = {}
7+
exec("from cupy.fft import *", _n)
8+
del _n["__builtins__"]
89
fft_all = list(_n)
910
del _n
1011

array_api_compat/cupy/linalg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# cupy.linalg doesn't have __all__. If it is added, replace this with
33
#
44
# from cupy.linalg import __all__ as linalg_all
5-
_n = {}
5+
_n: dict[str, object] = {}
66
exec('from cupy.linalg import *', _n)
77
del _n['__builtins__']
88
linalg_all = list(_n)

array_api_compat/dask/array/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from dask.array import * # noqa: F403
44

55
# These imports may overwrite names from the import * above.
6-
from ._aliases import * # noqa: F403
6+
from ._aliases import * # type: ignore[assignment] # noqa: F403
77

88
__array_api_version__: Final = "2024.12"
99

array_api_compat/dask/array/_aliases.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def arange(
146146

147147
# asarray also adds the copy keyword, which is not present in numpy 1.0.
148148
def asarray(
149-
obj: complex | NestedSequence[complex] | Array | SupportsBufferProtocol,
149+
obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol,
150150
/,
151151
*,
152152
dtype: DType | None = None,

0 commit comments

Comments
 (0)