from functools import reduce as _reduce, wraps as _wraps
from builtins import all as _builtin_all, any as _builtin_any
-from typing import List, Optional, Sequence, Tuple, Union
+from typing import Any, List, Optional, Sequence, Tuple, Union

import torch

from .._internal import get_xp
from ..common import _aliases
+from ..common._typing import NestedSequence, SupportsBufferProtocol
from ._info import __array_namespace_info__
from ._typing import Array, Device, DType

@@ -207,6 +208,28 @@ def can_cast(from_: Union[DType, Array], to: DType, /) -> bool:
remainder = _two_arg(torch.remainder)
subtract = _two_arg(torch.subtract)

+
+def asarray(
+    obj: (
+        Array
+        | bool | int | float | complex
+        | NestedSequence[bool | int | float | complex]
+        | SupportsBufferProtocol
+    ),
+    /,
+    *,
+    dtype: DType | None = None,
+    device: Device | None = None,
+    copy: bool | None = None,
+    **kwargs: Any,
+) -> Array:
+    # torch.asarray does not respect input->output device propagation
+    # https://github.com/pytorch/pytorch/issues/150199
+    if device is None and isinstance(obj, torch.Tensor):
+        device = obj.device
+    return torch.asarray(obj, dtype=dtype, device=device, copy=copy, **kwargs)
+
+
# These wrappers are mostly based on the fact that pytorch uses 'dim' instead
# of 'axis'.

@@ -285,7 +308,6 @@ def prod(x: Array,
         dtype: Optional[DType] = None,
         keepdims: bool = False,
         **kwargs) -> Array:
-    x = torch.asarray(x)
    ndim = x.ndim

    # https://github.com/pytorch/pytorch/issues/29137. Separate from the logic
@@ -321,7 +343,6 @@ def sum(x: Array,
        dtype: Optional[DType] = None,
        keepdims: bool = False,
        **kwargs) -> Array:
-    x = torch.asarray(x)
    ndim = x.ndim

    # https://github.com/pytorch/pytorch/issues/29137.
@@ -351,7 +372,6 @@ def any(x: Array,
        axis: Optional[Union[int, Tuple[int, ...]]] = None,
        keepdims: bool = False,
        **kwargs) -> Array:
-    x = torch.asarray(x)
    ndim = x.ndim
    if axis == ():
        return x.to(torch.bool)
@@ -376,7 +396,6 @@ def all(x: Array,
        axis: Optional[Union[int, Tuple[int, ...]]] = None,
        keepdims: bool = False,
        **kwargs) -> Array:
-    x = torch.asarray(x)
    ndim = x.ndim
    if axis == ():
        return x.to(torch.bool)
@@ -819,7 +838,7 @@ def sign(x: Array, /) -> Array:
    return out


-__all__ = ['__array_namespace_info__', 'result_type', 'can_cast',
+__all__ = ['__array_namespace_info__', 'asarray', 'result_type', 'can_cast',
           'permute_dims', 'bitwise_invert', 'newaxis', 'conj', 'add',
           'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or',
           'bitwise_right_shift', 'bitwise_xor', 'copysign', 'count_nonzero',
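
For context, a minimal usage sketch of the asarray wrapper added above. The import path array_api_compat.torch and the availability of a CUDA device are assumptions for illustration, not part of this change:

import torch
import array_api_compat.torch as xp  # assumed import path for the wrapped namespace

# With the wrapper above, an input tensor's device is propagated to the
# output when no explicit device is requested, working around
# https://github.com/pytorch/pytorch/issues/150199.
x = torch.ones(3, device="cuda")   # assumes a CUDA device is available
y = xp.asarray(x)                  # no device given -> defaults to x.device
assert y.device == x.device

# An explicit device argument still overrides the input's device.
z = xp.asarray(x, device="cpu")
assert z.device.type == "cpu"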