Skip to content

Commit 517d0c8

Browse files
committed
cosmetic
1 parent 5158082 commit 517d0c8

File tree

1 file changed

+37
-24
lines changed

1 file changed

+37
-24
lines changed

array_api_compat/dask/array/_aliases.py

Lines changed: 37 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

33
from ...common import _aliases
4-
from ...common._helpers import _check_device
54

65
from ..._internal import get_xp
76

@@ -40,19 +39,25 @@
4039
isdtype = get_xp(np)(_aliases.isdtype)
4140
unstack = get_xp(da)(_aliases.unstack)
4241

42+
# da.astype doesn't respect copy=True
4343
def astype(
4444
x: Array,
4545
dtype: Dtype,
4646
/,
4747
*,
4848
copy: bool = True,
49-
device: Device | None = None
49+
device: Optional[Device] = None
5050
) -> Array:
51+
"""
52+
Array API compatibility wrapper for astype().
53+
54+
See the corresponding documentation in the array library and/or the array API
55+
specification for more details.
56+
"""
5157
# TODO: respect device keyword?
58+
5259
if not copy and dtype == x.dtype:
5360
return x
54-
# dask astype doesn't respect copy=True,
55-
# so call copy manually afterwards
5661
x = x.astype(dtype)
5762
return x.copy() if copy else x
5863

@@ -61,20 +66,24 @@ def astype(
6166
# This arange func is modified from the common one to
6267
# not pass stop/step as keyword arguments, which will cause
6368
# an error with dask
64-
65-
# TODO: delete the xp stuff, it shouldn't be necessary
66-
def _dask_arange(
69+
def arange(
6770
start: Union[int, float],
6871
/,
6972
stop: Optional[Union[int, float]] = None,
7073
step: Union[int, float] = 1,
7174
*,
72-
xp,
7375
dtype: Optional[Dtype] = None,
7476
device: Optional[Device] = None,
7577
**kwargs,
7678
) -> Array:
77-
_check_device(xp, device)
79+
"""
80+
Array API compatibility wrapper for arange().
81+
82+
See the corresponding documentation in the array library and/or the array API
83+
specification for more details.
84+
"""
85+
# TODO: respect device keyword?
86+
7887
args = [start]
7988
if stop is not None:
8089
args.append(stop)
@@ -83,13 +92,12 @@ def _dask_arange(
8392
# prepend the default value for start which is 0
8493
args.insert(0, 0)
8594
args.append(step)
86-
return xp.arange(*args, dtype=dtype, **kwargs)
8795

88-
arange = get_xp(da)(_dask_arange)
89-
eye = get_xp(da)(_aliases.eye)
96+
return da.arange(*args, dtype=dtype, **kwargs)
97+
9098

91-
linspace = get_xp(da)(_aliases.linspace)
9299
eye = get_xp(da)(_aliases.eye)
100+
linspace = get_xp(da)(_aliases.linspace)
93101
UniqueAllResult = get_xp(da)(_aliases.UniqueAllResult)
94102
UniqueCountsResult = get_xp(da)(_aliases.UniqueCountsResult)
95103
UniqueInverseResult = get_xp(da)(_aliases.UniqueInverseResult)
@@ -112,7 +120,6 @@ def _dask_arange(
112120
reshape = get_xp(da)(_aliases.reshape)
113121
matrix_transpose = get_xp(da)(_aliases.matrix_transpose)
114122
vecdot = get_xp(da)(_aliases.vecdot)
115-
116123
nonzero = get_xp(da)(_aliases.nonzero)
117124
ceil = get_xp(np)(_aliases.ceil)
118125
floor = get_xp(np)(_aliases.floor)
@@ -121,6 +128,7 @@ def _dask_arange(
121128
tensordot = get_xp(np)(_aliases.tensordot)
122129
sign = get_xp(np)(_aliases.sign)
123130

131+
124132
# asarray also adds the copy keyword, which is not present in numpy 1.0.
125133
def asarray(
126134
obj: Union[
@@ -135,7 +143,7 @@ def asarray(
135143
*,
136144
dtype: Optional[Dtype] = None,
137145
device: Optional[Device] = None,
138-
copy: "Optional[Union[bool, np._CopyMode]]" = None,
146+
copy: Optional[Union[bool, np._CopyMode]] = None,
139147
**kwargs,
140148
) -> Array:
141149
"""
@@ -144,6 +152,8 @@ def asarray(
144152
See the corresponding documentation in the array library and/or the array API
145153
specification for more details.
146154
"""
155+
# TODO: respect device keyword?
156+
147157
if isinstance(obj, da.Array):
148158
if dtype is not None and dtype != obj.dtype:
149159
if copy is False:
@@ -183,15 +193,18 @@ def asarray(
183193
# Furthermore, the masking workaround in common._aliases.clip cannot work with
184194
# dask (meaning uint64 promoting to float64 is going to just be unfixed for
185195
# now).
186-
@get_xp(da)
187196
def clip(
188197
x: Array,
189198
/,
190199
min: Optional[Union[int, float, Array]] = None,
191200
max: Optional[Union[int, float, Array]] = None,
192-
*,
193-
xp,
194201
) -> Array:
202+
"""
203+
Array API compatibility wrapper for clip().
204+
205+
See the corresponding documentation in the array library and/or the array API
206+
specification for more details.
207+
"""
195208
def _isscalar(a):
196209
return isinstance(a, (int, float, type(None)))
197210
min_shape = () if _isscalar(min) else min.shape
@@ -201,19 +214,19 @@ def _isscalar(a):
201214
result_shape = np.broadcast_shapes(x.shape, min_shape, max_shape)
202215

203216
if min is not None:
204-
min = xp.broadcast_to(xp.asarray(min), result_shape)
217+
min = da.broadcast_to(da.asarray(min), result_shape)
205218
if max is not None:
206-
max = xp.broadcast_to(xp.asarray(max), result_shape)
219+
max = da.broadcast_to(da.asarray(max), result_shape)
207220

208221
if min is None and max is None:
209-
return xp.positive(x)
222+
return da.positive(x)
210223

211224
if min is None:
212-
return astype(xp.minimum(x, max), x.dtype)
225+
return astype(da.minimum(x, max), x.dtype)
213226
if max is None:
214-
return astype(xp.maximum(x, min), x.dtype)
227+
return astype(da.maximum(x, min), x.dtype)
215228

216-
return astype(xp.minimum(xp.maximum(x, min), max), x.dtype)
229+
return astype(da.minimum(da.maximum(x, min), max), x.dtype)
217230

218231
# exclude these from all since dask.array has no sorting functions
219232
_da_unsupported = ['sort', 'argsort']

0 commit comments

Comments
 (0)