Skip to content

Commit a0e6878

Browse files
committed
Merge branch 'main' into iinfo
2 parents f80f157 + 621494b commit a0e6878

File tree

2 files changed

+27
-9
lines changed

2 files changed

+27
-9
lines changed

array_api_compat/torch/_aliases.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22

33
from functools import reduce as _reduce, wraps as _wraps
44
from builtins import all as _builtin_all, any as _builtin_any
5-
from typing import List, Optional, Sequence, Tuple, Union
5+
from typing import Any, List, Optional, Sequence, Tuple, Union
66

77
import torch
88

99
from .._internal import get_xp
1010
from ..common import _aliases
11+
from ..common._typing import NestedSequence, SupportsBufferProtocol
1112
from ._info import __array_namespace_info__
1213
from ._typing import Array, Device, DType
1314

@@ -207,6 +208,28 @@ def can_cast(from_: Union[DType, Array], to: DType, /) -> bool:
207208
remainder = _two_arg(torch.remainder)
208209
subtract = _two_arg(torch.subtract)
209210

211+
212+
def asarray(
213+
obj: (
214+
Array
215+
| bool | int | float | complex
216+
| NestedSequence[bool | int | float | complex]
217+
| SupportsBufferProtocol
218+
),
219+
/,
220+
*,
221+
dtype: DType | None = None,
222+
device: Device | None = None,
223+
copy: bool | None = None,
224+
**kwargs: Any,
225+
) -> Array:
226+
# torch.asarray does not respect input->output device propagation
227+
# https://github.com/pytorch/pytorch/issues/150199
228+
if device is None and isinstance(obj, torch.Tensor):
229+
device = obj.device
230+
return torch.asarray(obj, dtype=dtype, device=device, copy=copy, **kwargs)
231+
232+
210233
# These wrappers are mostly based on the fact that pytorch uses 'dim' instead
211234
# of 'axis'.
212235

@@ -285,7 +308,6 @@ def prod(x: Array,
285308
dtype: Optional[DType] = None,
286309
keepdims: bool = False,
287310
**kwargs) -> Array:
288-
x = torch.asarray(x)
289311
ndim = x.ndim
290312

291313
# https://github.com/pytorch/pytorch/issues/29137. Separate from the logic
@@ -321,7 +343,6 @@ def sum(x: Array,
321343
dtype: Optional[DType] = None,
322344
keepdims: bool = False,
323345
**kwargs) -> Array:
324-
x = torch.asarray(x)
325346
ndim = x.ndim
326347

327348
# https://github.com/pytorch/pytorch/issues/29137.
@@ -351,7 +372,6 @@ def any(x: Array,
351372
axis: Optional[Union[int, Tuple[int, ...]]] = None,
352373
keepdims: bool = False,
353374
**kwargs) -> Array:
354-
x = torch.asarray(x)
355375
ndim = x.ndim
356376
if axis == ():
357377
return x.to(torch.bool)
@@ -376,7 +396,6 @@ def all(x: Array,
376396
axis: Optional[Union[int, Tuple[int, ...]]] = None,
377397
keepdims: bool = False,
378398
**kwargs) -> Array:
379-
x = torch.asarray(x)
380399
ndim = x.ndim
381400
if axis == ():
382401
return x.to(torch.bool)
@@ -819,7 +838,7 @@ def sign(x: Array, /) -> Array:
819838
return out
820839

821840

822-
__all__ = ['__array_namespace_info__', 'result_type', 'can_cast',
841+
__all__ = ['__array_namespace_info__', 'asarray', 'result_type', 'can_cast',
823842
'permute_dims', 'bitwise_invert', 'newaxis', 'conj', 'add',
824843
'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or',
825844
'bitwise_right_shift', 'bitwise_xor', 'copysign', 'count_nonzero',

array_api_compat/torch/_typing.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
__all__ = ["Array", "DType", "Device"]
1+
__all__ = ["Array", "Device", "DType"]
22

3-
from torch import dtype as DType, Tensor as Array
4-
from ..common._typing import Device
3+
from torch import device as Device, dtype as DType, Tensor as Array

0 commit comments

Comments
 (0)