1
1
from __future__ import annotations
2
2
3
3
from ...common import _aliases
4
- from ...common ._helpers import _check_device
5
4
6
5
from ..._internal import get_xp
7
6
40
39
isdtype = get_xp (np )(_aliases .isdtype )
41
40
unstack = get_xp (da )(_aliases .unstack )
42
41
42
+ # da.astype doesn't respect copy=True
43
43
def astype (
44
44
x : Array ,
45
45
dtype : Dtype ,
46
46
/ ,
47
47
* ,
48
48
copy : bool = True ,
49
- device : Device | None = None
49
+ device : Optional [ Device ] = None
50
50
) -> Array :
51
+ """
52
+ Array API compatibility wrapper for astype().
53
+
54
+ See the corresponding documentation in the array library and/or the array API
55
+ specification for more details.
56
+ """
51
57
# TODO: respect device keyword?
58
+
52
59
if not copy and dtype == x .dtype :
53
60
return x
54
- # dask astype doesn't respect copy=True,
55
- # so call copy manually afterwards
56
61
x = x .astype (dtype )
57
62
return x .copy () if copy else x
58
63
@@ -61,20 +66,24 @@ def astype(
61
66
# This arange func is modified from the common one to
62
67
# not pass stop/step as keyword arguments, which will cause
63
68
# an error with dask
64
-
65
- # TODO: delete the xp stuff, it shouldn't be necessary
66
- def _dask_arange (
69
+ def arange (
67
70
start : Union [int , float ],
68
71
/ ,
69
72
stop : Optional [Union [int , float ]] = None ,
70
73
step : Union [int , float ] = 1 ,
71
74
* ,
72
- xp ,
73
75
dtype : Optional [Dtype ] = None ,
74
76
device : Optional [Device ] = None ,
75
77
** kwargs ,
76
78
) -> Array :
77
- _check_device (xp , device )
79
+ """
80
+ Array API compatibility wrapper for arange().
81
+
82
+ See the corresponding documentation in the array library and/or the array API
83
+ specification for more details.
84
+ """
85
+ # TODO: respect device keyword?
86
+
78
87
args = [start ]
79
88
if stop is not None :
80
89
args .append (stop )
@@ -83,13 +92,12 @@ def _dask_arange(
83
92
# prepend the default value for start which is 0
84
93
args .insert (0 , 0 )
85
94
args .append (step )
86
- return xp .arange (* args , dtype = dtype , ** kwargs )
87
95
88
- arange = get_xp ( da )( _dask_arange )
89
- eye = get_xp ( da )( _aliases . eye )
96
+ return da . arange ( * args , dtype = dtype , ** kwargs )
97
+
90
98
91
- linspace = get_xp (da )(_aliases .linspace )
92
99
eye = get_xp (da )(_aliases .eye )
100
+ linspace = get_xp (da )(_aliases .linspace )
93
101
UniqueAllResult = get_xp (da )(_aliases .UniqueAllResult )
94
102
UniqueCountsResult = get_xp (da )(_aliases .UniqueCountsResult )
95
103
UniqueInverseResult = get_xp (da )(_aliases .UniqueInverseResult )
@@ -112,7 +120,6 @@ def _dask_arange(
112
120
reshape = get_xp (da )(_aliases .reshape )
113
121
matrix_transpose = get_xp (da )(_aliases .matrix_transpose )
114
122
vecdot = get_xp (da )(_aliases .vecdot )
115
-
116
123
nonzero = get_xp (da )(_aliases .nonzero )
117
124
ceil = get_xp (np )(_aliases .ceil )
118
125
floor = get_xp (np )(_aliases .floor )
@@ -121,6 +128,7 @@ def _dask_arange(
121
128
tensordot = get_xp (np )(_aliases .tensordot )
122
129
sign = get_xp (np )(_aliases .sign )
123
130
131
+
124
132
# asarray also adds the copy keyword, which is not present in numpy 1.0.
125
133
def asarray (
126
134
obj : Union [
@@ -135,7 +143,7 @@ def asarray(
135
143
* ,
136
144
dtype : Optional [Dtype ] = None ,
137
145
device : Optional [Device ] = None ,
138
- copy : " Optional[Union[bool, np._CopyMode]]" = None ,
146
+ copy : Optional [Union [bool , np ._CopyMode ]] = None ,
139
147
** kwargs ,
140
148
) -> Array :
141
149
"""
@@ -144,6 +152,8 @@ def asarray(
144
152
See the corresponding documentation in the array library and/or the array API
145
153
specification for more details.
146
154
"""
155
+ # TODO: respect device keyword?
156
+
147
157
if isinstance (obj , da .Array ):
148
158
if dtype is not None and dtype != obj .dtype :
149
159
if copy is False :
@@ -183,15 +193,18 @@ def asarray(
183
193
# Furthermore, the masking workaround in common._aliases.clip cannot work with
184
194
# dask (meaning uint64 promoting to float64 is going to just be unfixed for
185
195
# now).
186
- @get_xp (da )
187
196
def clip (
188
197
x : Array ,
189
198
/ ,
190
199
min : Optional [Union [int , float , Array ]] = None ,
191
200
max : Optional [Union [int , float , Array ]] = None ,
192
- * ,
193
- xp ,
194
201
) -> Array :
202
+ """
203
+ Array API compatibility wrapper for clip().
204
+
205
+ See the corresponding documentation in the array library and/or the array API
206
+ specification for more details.
207
+ """
195
208
def _isscalar (a ):
196
209
return isinstance (a , (int , float , type (None )))
197
210
min_shape = () if _isscalar (min ) else min .shape
@@ -201,19 +214,19 @@ def _isscalar(a):
201
214
result_shape = np .broadcast_shapes (x .shape , min_shape , max_shape )
202
215
203
216
if min is not None :
204
- min = xp .broadcast_to (xp .asarray (min ), result_shape )
217
+ min = da .broadcast_to (da .asarray (min ), result_shape )
205
218
if max is not None :
206
- max = xp .broadcast_to (xp .asarray (max ), result_shape )
219
+ max = da .broadcast_to (da .asarray (max ), result_shape )
207
220
208
221
if min is None and max is None :
209
- return xp .positive (x )
222
+ return da .positive (x )
210
223
211
224
if min is None :
212
- return astype (xp .minimum (x , max ), x .dtype )
225
+ return astype (da .minimum (x , max ), x .dtype )
213
226
if max is None :
214
- return astype (xp .maximum (x , min ), x .dtype )
227
+ return astype (da .maximum (x , min ), x .dtype )
215
228
216
- return astype (xp .minimum (xp .maximum (x , min ), max ), x .dtype )
229
+ return astype (da .minimum (da .maximum (x , min ), max ), x .dtype )
217
230
218
231
# exclude these from all since dask.array has no sorting functions
219
232
_da_unsupported = ['sort' , 'argsort' ]
0 commit comments