Skip to content

Commit 46e1060

Browse files
committed
Add support for device kwarg in astype to match Array API
1 parent 604f5ec commit 46e1060

File tree

6 files changed

+81
-10
lines changed

6 files changed

+81
-10
lines changed

CHANGELOG.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@ Remember to align the itemized text with the first line of an item within a list
88

99
## jax 0.4.26
1010

11+
* New Features
12+
* {func}`jax.numpy.astype` supports new `device` keyword argument.
13+
1114
* Deprecations & Removals
1215
* {func}`jax.tree_map` is deprecated; use `jax.tree.map` instead, or for backward
1316
compatibility with older JAX versions, use {func}`jax.tree_util.tree_map`.
@@ -28,6 +31,12 @@ Remember to align the itemized text with the first line of an item within a list
2831
`jax.interpreters.ad.source_info_util` have now been removed. Use `jax.config`
2932
and `jax.extend.source_info_util` instead.
3033

34+
* Bug fixes
35+
* {func}`jax.numpy.astype` will now always return a copy when `copy=True`.
36+
Previously, no copy would be made when the output array would have the same
37+
dtype as the input array. This may result in some increased memory usage.
38+
To prevent copying when possible, set `copy=False`.
39+
3140
## jaxlib 0.4.26
3241

3342
## jax 0.4.25 (Feb 26, 2024)

jax/_src/numpy/array_methods.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,13 @@
3131
import numpy as np
3232
import jax
3333
from jax import lax
34+
from jax.sharding import Sharding
3435
from jax._src import core
3536
from jax._src import dtypes
3637
from jax._src.api_util import _ensure_index_tuple
3738
from jax._src.array import ArrayImpl
3839
from jax._src.lax import lax as lax_internal
40+
from jax._src.lib import xla_client as xc
3941
from jax._src.numpy import lax_numpy
4042
from jax._src.numpy import reductions
4143
from jax._src.numpy import ufuncs
@@ -55,15 +57,15 @@
5557
# functions, which can themselves handle instances from any of these classes.
5658

5759

58-
def _astype(arr: ArrayLike, dtype: DTypeLike) -> Array:
60+
def _astype(arr: ArrayLike, dtype: DTypeLike, copy: bool = True, device: xc.Device | Sharding | None = None) -> Array:
5961
"""Copy the array and cast to a specified dtype.
6062
6163
This is implemented via :func:`jax.lax.convert_element_type`, which may
6264
have slightly different behavior than :meth:`numpy.ndarray.astype` in
6365
some cases. In particular, the details of float-to-int and int-to-float
6466
casts are implementation dependent.
6567
"""
66-
return lax_numpy.astype(arr, dtype)
68+
return lax_numpy.astype(arr, dtype, copy=copy, device=device)
6769

6870

6971
def _nbytes(arr: ArrayLike) -> int:

jax/_src/numpy/lax_numpy.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
import opt_einsum
4242

4343
import jax
44-
from jax import jit
44+
from jax import jit, device_put
4545
from jax import errors
4646
from jax import lax
4747
from jax.sharding import Sharding, SingleDeviceSharding
@@ -2209,19 +2209,38 @@ def _convert_to_array_if_dtype_fails(x: ArrayLike) -> ArrayLike:
22092209
else:
22102210
return x
22112211

2212-
22132212
@util.implements(getattr(np, "astype", None), lax_description="""
22142213
This is implemented via :func:`jax.lax.convert_element_type`, which may
22152214
have slightly different behavior than :func:`numpy.astype` in some cases.
22162215
In particular, the details of float-to-int and int-to-float casts are
22172216
implementation dependent.
22182217
""")
2219-
def astype(x: ArrayLike, dtype: DTypeLike | None, /, *, copy: bool = True) -> Array:
2220-
del copy # unused in JAX
2218+
def astype(x: ArrayLike, dtype: DTypeLike | None, /, *, copy: bool = True, device: xc.Device | Sharding | None = None) -> Array:
22212219
if dtype is None:
22222220
dtype = dtypes.canonicalize_dtype(float_)
22232221
dtypes.check_user_dtype_supported(dtype, "astype")
2224-
return lax.convert_element_type(x, dtype)
2222+
src_dtype = x.dtype if hasattr(x, "dtype") else dtypes.dtype(x)
2223+
if (
2224+
src_dtype is not None
2225+
and dtypes.isdtype(src_dtype, "complex floating")
2226+
and dtypes.isdtype(dtype, ("integral", "real floating"))
2227+
):
2228+
raise ValueError(
2229+
"Casting from complex to non-complex dtypes is not permitted. Please "
2230+
"first use jnp.real or jnp.imag to take the real/imaginary component of "
2231+
"your input."
2232+
)
2233+
src_devices = (
2234+
x.devices() if hasattr(x, "devices")
2235+
and not isinstance(x, core.Tracer) else None
2236+
)
2237+
arr = x
2238+
if device is not None and src_devices != {device}:
2239+
arr = device_put(x, device)
2240+
elif copy:
2241+
arr = _array_copy(x)
2242+
return lax.convert_element_type(arr, dtype)
2243+
22252244

22262245

22272246
@util.implements(np.asarray, lax_description=_ARRAY_DOC)

jax/experimental/array_api/_data_type_functions.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,17 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from __future__ import annotations
1516

17+
import builtins
1618
import functools
1719
from typing import NamedTuple
1820
import jax
1921
import jax.numpy as jnp
2022

2123

24+
from jax._src.lib import xla_client as xc
25+
from jax._src.sharding import Sharding
2226
from jax.experimental.array_api._dtypes import (
2327
bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64,
2428
float32, float64, complex64, complex128
@@ -124,8 +128,8 @@ def _promote_types(t1, t2):
124128
raise ValueError("No promotion path for {t1} & {t2}")
125129

126130

127-
def astype(x, dtype, /, *, copy=True):
128-
return jnp.array(x, dtype=dtype, copy=copy)
131+
def astype(x, dtype, /, *, copy: builtins.bool = True, device: xc.Device | Sharding | None = None):
132+
return jnp.astype(x, dtype, copy=copy, device=device)
129133

130134

131135
def can_cast(from_, to, /):

jax/numpy/__init__.pyi

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ from jax._src import dtypes as _dtypes
99
from jax._src.lax.lax import PrecisionLike
1010
from jax._src.lax.slicing import GatherScatterMode
1111
from jax._src.numpy.index_tricks import _Mgrid, _Ogrid, CClass as _CClass, RClass as _RClass
12+
from jax._src.sharding import Sharding
13+
from jax._src.lib import xla_client as xc
1214
from jax._src.typing import Array, ArrayLike, DType, DTypeLike, DimSize, DuckTypedArray, Shape
1315
from jax.numpy import fft as fft, linalg as linalg
1416
from jax.sharding import Sharding as _Sharding
@@ -112,7 +114,7 @@ def asarray(
112114
) -> Array: ...
113115
def asin(x: ArrayLike, /) -> Array: ...
114116
def asinh(x: ArrayLike, /) -> Array: ...
115-
def astype(a: ArrayLike, dtype: Optional[DTypeLike], /, *, copy: builtins.bool = ...) -> Array: ...
117+
def astype(a: ArrayLike, dtype: Optional[DTypeLike], /, *, copy: builtins.bool = ..., device: xc.Device | Sharding | None = ...) -> Array: ...
116118
def atan(x: ArrayLike, /) -> Array: ...
117119
def atan2(x: ArrayLike, y: ArrayLike, /) -> Array: ...
118120
def atanh(x: ArrayLike, /) -> Array: ...

tests/lax_numpy_test.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3789,6 +3789,41 @@ def testAstype(self, from_dtype, to_dtype, use_method):
37893789
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
37903790
self._CompileAndCheck(jnp_op, args_maker)
37913791

3792+
@jtu.sample_product(
3793+
change_dtype=[True, False],
3794+
copy=[True, False],
3795+
change_device=[True, False],
3796+
)
3797+
def testAstypeCopy(self, change_dtype, copy, change_device):
3798+
if jax.device_count() == 1 and change_device:
3799+
raise unittest.SkipTest(
3800+
"Testing device transfer requires at least two available devices."
3801+
)
3802+
3803+
dtype = 'float32' if change_dtype else 'int32'
3804+
device = jax.devices()[-1] if change_device else None
3805+
expect_copy = change_dtype or copy or change_device
3806+
x = jnp.arange(5, dtype='int32')
3807+
y = x.astype(dtype, copy=copy, device=device)
3808+
3809+
assert y.dtype == dtype
3810+
if change_device:
3811+
assert y.devices() == {device}
3812+
else:
3813+
y.delete()
3814+
get_val = lambda: np.array(x)
3815+
err_msg = "Array has been deleted"
3816+
if expect_copy:
3817+
get_val()
3818+
else:
3819+
jtu.check_raises(get_val, RuntimeError, err_msg)
3820+
3821+
def testAstypeComplexDowncast(self):
3822+
x = jnp.array(2.0+1.5j, dtype='complex64')
3823+
complex_downcast = lambda: x.astype('float32')
3824+
err_msg = "Casting from complex to non-complex "
3825+
jtu.check_raises(complex_downcast, ValueError, err_msg)
3826+
37923827
def testAstypeInt4(self):
37933828
# Test converting from int4 to int8
37943829
x = np.array([1, -2, -3, 4, -8, 7], dtype=jnp.int4)

0 commit comments

Comments
 (0)