Add support for max_version, dl_device, copy kwargs in __dlpack__ to match Array API #20198

Merged · 1 commit · Apr 11, 2024
24 changes: 19 additions & 5 deletions jax/_src/array.py
@@ -46,7 +46,7 @@
SingleDeviceSharding, XLACompatibleSharding, PmapSharding,
device_replica_id_map, hashed_index)
from jax._src.layout import DeviceLocalLayout, Layout
-from jax._src.typing import ArrayLike
+from jax._src.typing import ArrayLike, DLDeviceType
from jax._src.util import safe_zip, unzip3, use_cpp_class, use_cpp_method


@@ -404,11 +404,25 @@ def __array__(self, dtype=None, context=None, copy=None):
kwds = {} if copy is None else {'copy': copy}
return np.asarray(self._value, dtype=dtype, **kwds)

-  def __dlpack__(self, *, stream: int | Any | None = None):
-    if len(self._arrays) != 1:
-      raise BufferError("__dlpack__ only supported for unsharded arrays.")
+  def __dlpack__(self, *, stream: int | Any | None = None,
+                 max_version: tuple[int, int] | None = None,
+                 dl_device: tuple[DLDeviceType, int] | None = None,
+                 copy: bool | None = None):
from jax._src.dlpack import to_dlpack # pylint: disable=g-import-not-at-top
-    return to_dlpack(self, stream=stream)

+    device_set = self.sharding.device_set
+    if len(device_set) > 1:
+      raise BufferError(
+        "to_dlpack can only pack a dlpack tensor from an array on a singular "
+        f"device, but an array with a Sharding over {len(device_set)} devices "
+        "was provided."
+      )
+    device, = device_set
+    return to_dlpack(self, stream=stream,
+                     max_version=max_version,
+                     src_device=device,
+                     dl_device=dl_device,
+                     copy=copy)

def __dlpack_device__(self) -> tuple[enum.Enum, int]:
if len(self._arrays) != 1:
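For context, here is a consumer-side sketch of the extended protocol. It is illustrative only, not part of the diff, and assumes a CPU runtime plus a NumPy recent enough to ship np.from_dlpack, which drives __dlpack__ and __dlpack_device__ itself.

# Hypothetical usage sketch: exercising the new __dlpack__ keywords the
# way an Array API consumer would.
import jax.numpy as jnp
import numpy as np

x = jnp.arange(4.0)

# The producer reports its device as a (DLDeviceType, id) pair...
dl_device = x.__dlpack_device__()
# ...and the consumer passes it back, capping the capsule version it
# understands; copy=None means "copy only if needed".
capsule = x.__dlpack__(dl_device=dl_device, max_version=(1, 0), copy=None)

# In practice a consuming library performs this handshake for you:
y = np.from_dlpack(x)
print(y)  # [0. 1. 2. 3.]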
117 changes: 93 additions & 24 deletions jax/_src/dlpack.py
@@ -14,19 +14,21 @@

from __future__ import annotations

-import enum
from typing import Any
-import warnings

from jax._src.api import device_put
from jax import numpy as jnp
from jax._src import array
from jax._src import xla_bridge
from jax._src.lax.lax import _array_copy
from jax._src.lib import xla_client
from jax._src.lib import xla_extension_version
-from jax._src.typing import Array
+from jax._src.typing import Array, DLDeviceType
from jax._src.sharding import Sharding

+DLPACK_VERSION = (0, 8)
+MIN_DLPACK_VERSION = (0, 5)

# A set of dtypes that dlpack supports.
# Note: Make sure to use a "type", not a dtype instance, when looking up this set
# because their hashes are different.
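The two new constants drive the version negotiation in to_dlpack further down; a small aside (plain Python, not part of the diff) on why bare 2-tuples are enough here:

# Illustration only: (major, minor) tuples compare lexicographically,
# which is exactly what the max_version gating in to_dlpack relies on.
DLPACK_VERSION = (0, 8)
MIN_DLPACK_VERSION = (0, 5)

assert (1, 0) >= DLPACK_VERSION      # a DLPack 1.0 consumer takes the latest path
assert (0, 6) >= MIN_DLPACK_VERSION  # still within the supported range
assert (0, 4) < MIN_DLPACK_VERSION   # too old: to_dlpack raises BufferError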
@@ -43,45 +45,112 @@
SUPPORTED_DTYPES = SUPPORTED_DTYPES | frozenset({jnp.bool_})


-# Mirror of dlpack.h enum
-class DLDeviceType(enum.IntEnum):
-  kDLCPU = 1
-  kDLCUDA = 2
-  kDLROCM = 10
+def _to_dlpack(x: Array, stream: int | Any | None,
+               src_device: xla_client.Device | None = None,
+               device: xla_client.Device | None = None,
+               copy: bool | None = None):
+  if src_device is None:
+    src_device, = x.devices()
+  if device and (src_device is None or device != src_device):
+    if copy is not None and not copy:
+      raise ValueError(
+        f"Specified {device=} which requires a copy since the source device "
+        f"is {repr(src_device)}, however copy=False. Set copy=True or "
+        "copy=None to perform the requested operation."
+      )
+    else:
+      arr = device_put(x, device)
+  else:
+    arr = _array_copy(x) if copy else x
+  return xla_client._xla.buffer_to_dlpack_managed_tensor(
+    arr.addressable_data(0), stream=stream
+  )

-def to_dlpack(x: Array, take_ownership: bool = False,
-              stream: int | Any | None = None):
+def to_dlpack(x: Array, stream: int | Any | None = None,
+              src_device: xla_client.Device | None = None,
+              dl_device: tuple[DLDeviceType, int] | None = None,
+              max_version: tuple[int, int] | None = None,
+              copy: bool | None = None):
"""Returns a DLPack tensor that encapsulates a :class:`~jax.Array` ``x``.

Args:
x: a :class:`~jax.Array`, on either CPU or GPU.
-    take_ownership: Deprecated. It is a no-op to set take_ownership. Will be
-      deleted in 01/2024.
stream: optional platform-dependent stream to wait on until the buffer is
ready. This corresponds to the `stream` argument to ``__dlpack__``
documented in https://dmlc.github.io/dlpack/latest/python_spec.html.
+    src_device: either a CPU or GPU :class:`~jax.Device`.
+    dl_device: a tuple of ``(dl_device_type, local_hardware_id)`` in DLPack
+      format e.g. as produced by ``__dlpack_device__``.
+    max_version: the maximum DLPack version that the consumer (i.e. caller of
+      ``__dlpack__``) supports in the form of a 2-tuple of ``(major, minor)``.
+      This function is not guaranteed to return a capsule of version
+      ``max_version``.
+    copy: a boolean indicating whether or not to copy the input. If
+      ``copy=True`` then the function must always copy. When
+      ``copy=False`` then the function must never copy, and must raise an error
+      when a copy is deemed necessary. If ``copy=None`` then the function must
+      avoid a copy if possible but may copy if needed.

Returns:
-    A dlpack PyCapsule object.
+    A DLPack PyCapsule object.

Note:
-    While JAX arrays are always immutable, dlpack buffers cannot be marked as
-    immutable, and it is possible for processes external to JAX to mutate them
-    in-place. If a dlpack buffer derived from a JAX array is mutated, it may
-    lead to undefined behavior when using the associated JAX array.
+    While JAX arrays are always immutable, ``DLPackManagedTensor`` buffers
+    cannot be marked as immutable, and it is possible for processes external
+    to JAX to mutate them in-place. If a DLPack buffer derived from a JAX array
+    is mutated, it may lead to undefined behavior when using the associated JAX
+    array. When JAX eventually supports ``DLManagedTensorVersioned``
+    (DLPack 1.0), it will be possible to specify that a buffer is read-only.
"""
if not isinstance(x, array.ArrayImpl):
raise TypeError("Argument to to_dlpack must be a jax.Array, "
f"got {type(x)}")
-  assert len(x.devices()) == 1
-  if take_ownership:
-    warnings.warn(
-        "take_ownership in to_dlpack is deprecated and it is a no-op."
-    )

+  device = None
+  dl_device_type, local_hardware_id = dl_device if dl_device else (None, None)
+  if dl_device_type:
+    try:
+      dl_device_platform = {
+          DLDeviceType.kDLCPU: "cpu",
+          DLDeviceType.kDLCUDA: "cuda",
+          DLDeviceType.kDLROCM: "rocm",
+      }[dl_device_type]
+      backend = xla_bridge.get_backend(dl_device_platform)
+      device = backend.device_from_local_hardware_id(local_hardware_id)
+    except TypeError:
+      # https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__dlpack__.html
+      # recommends using BufferError.
+      raise BufferError(
+          "The device specification passed to to_dlpack contains an unsupported "
+          f"device type (DLDeviceType: {dl_device_type})")

+  # As new versions are adopted over time, we can maintain some legacy paths
+  # for compatibility mediated through the max_version parameter.
+  # TODO(micky774): Deprecate default usage of DLPackManagedTensor when XLA
+  # supports DLManagedTensorVersioned (DLPack version 1.0) and repurpose the
+  # current _to_dlpack as a legacy path for (0,5) <= max_version < (1,0).
+  if max_version is None or max_version >= DLPACK_VERSION:
+    # Latest
+    return _to_dlpack(
+        x, stream=stream,
+        src_device=src_device,
+        device=device,
+        copy=copy
+    )
+  elif max_version >= MIN_DLPACK_VERSION:
+    # Oldest supported
+    return _to_dlpack(
+        x, stream=stream,
+        src_device=src_device,
+        device=device,
+        copy=copy
+    )
+  else:
+    raise BufferError(
+        f"JAX does not support any version below {MIN_DLPACK_VERSION} but "
+        f"version ({max_version}) was requested."
+    )
-  return xla_client._xla.buffer_to_dlpack_managed_tensor(
-      x.addressable_data(0), stream=stream
-  )  # type: ignore

def _place_array(_arr, device, dlpack_device, copy):
if device and dlpack_device != device:
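Taken together, a hedged sketch of how the new to_dlpack keywords compose on a CPU-only runtime. The device id is illustrative, and DLDeviceType lives in the private jax._src.typing module, so treat that import path as an assumption.

# Illustrative sketch, not part of the diff.
import jax
import jax.numpy as jnp
from jax._src.typing import DLDeviceType  # private module: assumption

x = jnp.ones((2, 3))

# copy=True always hands the consumer its own buffer:
capsule = jax.dlpack.to_dlpack(x, copy=True)

# dl_device re-targets the export; with copy=False a cross-device export
# would raise ValueError instead of silently copying:
capsule = jax.dlpack.to_dlpack(x, dl_device=(DLDeviceType.kDLCPU, 0), copy=None)

# A consumer that only speaks DLPack 0.5 caps the capsule version;
# anything below (0, 5) raises BufferError:
capsule = jax.dlpack.to_dlpack(x, max_version=(0, 5))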
7 changes: 7 additions & 0 deletions jax/_src/typing.py
@@ -29,6 +29,7 @@
from collections.abc import Sequence
from typing import Any, Protocol, Union
import numpy as np
+import enum

from jax._src.basearray import (
Array as Array,
@@ -83,3 +84,9 @@ def shape(self) -> Shape: ...
class DeprecatedArg:
def __repr__(self):
return "Deprecated"

+# Mirror of dlpack.h enum
+class DLDeviceType(enum.IntEnum):
+  kDLCPU = 1
+  kDLCUDA = 2
+  kDLROCM = 10
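A small sketch of how this enum pairs with __dlpack_device__ output. The comparison works because both sides are integer-valued enums; the import path is internal and may change, so take it as an assumption.

# Illustration: matching __dlpack_device__ output against the enum.
import jax.numpy as jnp
from jax._src.typing import DLDeviceType

x = jnp.zeros(3)
device_type, device_id = x.__dlpack_device__()
if device_type == DLDeviceType.kDLCPU:
  print(f"exports as DLPack CPU device {device_id}")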
2 changes: 1 addition & 1 deletion jax/experimental/jax2tf/call_tf.py
@@ -334,7 +334,7 @@ def _arg_jax_to_tf(arg_jax):
if (isinstance(arg_jax, jax.Array) and
list(arg_jax.devices())[0].platform in _DLPACK_PLATFORMS and
arg_jax.dtype.type in dlpack.SUPPORTED_DTYPES):
-    arg_dlpack = jax.dlpack.to_dlpack(arg_jax, take_ownership=False)
+    arg_dlpack = jax.dlpack.to_dlpack(arg_jax)
return tf.experimental.dlpack.from_dlpack(arg_dlpack)
# The following avoids copies to the host on CPU, always for Array
# and even for ndarray if they are sufficiently aligned.
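The call_tf change simply drops the deprecated flag. For reference, a sketch of the JAX to TensorFlow hand-off it performs, assuming a TensorFlow build that provides tf.experimental.dlpack:

# Sketch of the jax -> tf DLPack hand-off used by call_tf.
import jax.dlpack
import jax.numpy as jnp
import tensorflow as tf

x = jnp.arange(3.0)
capsule = jax.dlpack.to_dlpack(x)  # take_ownership is gone: it was a no-op
t = tf.experimental.dlpack.from_dlpack(capsule)
print(t.numpy())  # [0. 1. 2.]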
49 changes: 37 additions & 12 deletions tests/array_interoperability_test.py
@@ -73,23 +73,48 @@ def setUp(self):
@jtu.sample_product(
shape=all_shapes,
dtype=dlpack_dtypes,
-    gpu=[False, True],
+    copy=[False, True, None]
)
-  def testJaxRoundTrip(self, shape, dtype, gpu):
+  @jtu.run_on_devices("gpu")
+  def testJaxRoundTrip(self, shape, dtype, copy):
if xb.using_pjrt_c_api():
self.skipTest("DLPack support is incomplete in the PJRT C API") # TODO(skyewm)
rng = jtu.rand_default(self.rng())
np = rng(shape, dtype)
if gpu and jtu.test_device_matches(["cpu"]):
raise unittest.SkipTest("Skipping GPU test case on CPU")
device = jax.devices("gpu" if gpu else "cpu")[0]
x = jax.device_put(np, device)
dlpack = jax.dlpack.to_dlpack(x)
y = jax.dlpack.from_dlpack(dlpack)
self.assertEqual(y.devices(), {device})
self.assertAllClose(np.astype(x.dtype), y)

+    def _check_copy(x: jax.Array, y: jax.Array, expect_copy):
+      copied = x.unsafe_buffer_pointer() != y.unsafe_buffer_pointer()
+      assert copied == expect_copy, f"Expected {'a' if expect_copy else 'no'} copy"

+    # Check if the source device is preserved
+    x = jax.device_put(np, jax.devices("cpu")[0])
+    device = jax.devices("gpu")[0]
+    y = jax.device_put(x, device)
+    dl_device = y.__dlpack_device__()
+    dlpack = jax.dlpack.to_dlpack(y, copy=copy)
+    z = jax.dlpack.from_dlpack(dlpack)

+    self.assertEqual(z.devices(), {device})
+    self.assertAllClose(np.astype(x.dtype), z)
     self.assertRaisesRegex(RuntimeError,
-                           "DLPack tensor may be consumed at most once",
-                           lambda: jax.dlpack.from_dlpack(dlpack))
+                          "DLPack tensor may be consumed at most once",
+                          lambda: jax.dlpack.from_dlpack(dlpack))

+    if shape in nonempty_array_shapes:
+      _check_copy(y, z, bool(copy))

+    # Check if the destination device can be specified
+    make_dlpack = lambda: x.__dlpack__(dl_device=dl_device, copy=copy)
+    if copy == False:
+      self.assertRaisesRegex(ValueError, "copy=False", make_dlpack)
+      return

+    z = jax.dlpack.from_dlpack(make_dlpack())
+    self.assertEqual(z.devices(), {device})
+    self.assertAllClose(x, z)

+    if shape in nonempty_array_shapes:
+      _check_copy(x, z, True)

@jtu.sample_product(
shape=all_shapes,
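The _check_copy helper above infers aliasing from raw buffer pointers. A standalone sketch of the same technique (CPU example; unsafe_buffer_pointer is a JAX-specific, single-device API, and the printed values assume the round trip behaves as the test asserts):

# Sketch: detect whether a DLPack round trip copied the buffer.
import jax
import jax.numpy as jnp

x = jnp.arange(4.0)
y = jax.dlpack.from_dlpack(jax.dlpack.to_dlpack(x, copy=False))
z = jax.dlpack.from_dlpack(jax.dlpack.to_dlpack(x, copy=True))

print("y aliases x:", x.unsafe_buffer_pointer() == y.unsafe_buffer_pointer())  # True
print("z aliases x:", x.unsafe_buffer_pointer() == z.unsafe_buffer_pointer())  # False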