
DLPack and readonly/immutable arrays #191

Closed
@seberg

Description


Sorry if there is more discussion (or just clarification) already in the pipeline, but I still think this is important, so here is an issue :). I realize there was some related discussion in JAX before, but I am not clear how deep it really went. (I have talked with Ralf about this; we disagree about how important it is, but I wonder what others think, and we need at least an issue or a clear decision somewhere.)


The problem (and current state/issues also with the buffer protocol that supports "readonly")

NumPy (and the buffer protocol) has readonly arrays (writeable=False). NumPy actually makes a slight distinction (I do not think it matters, though):

  • Readonly arrays for which NumPy owns the memory. These can be set to writeable by a (power) user.
  • Arrays that are truly readonly.
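A minimal sketch of the two cases, assuming a recent NumPy: an array that owns its memory can be flipped back to writeable, while one backed by an immutable `bytes` object cannot (NumPy raises `ValueError`). The helper `can_make_writeable` is hypothetical, written only for illustration.

```python
import numpy as np


def can_make_writeable(arr):
    """Hypothetical helper: try to re-enable writes on a readonly array."""
    try:
        arr.flags.writeable = True
        return True
    except ValueError:
        # NumPy refuses when the underlying buffer is truly readonly.
        return False


# Case 1: NumPy owns the memory, so "readonly" is just a reversible flag.
owned = np.arange(3.0)
owned.flags.writeable = False

# Case 2: the memory belongs to an immutable bytes object (3 float64 values).
frozen = np.frombuffer(bytes(24), dtype=np.float64)

print(can_make_writeable(owned))   # True
print(can_make_writeable(frozen))  # False
```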

The current situation in the "buffer protocol" (and array protocols) world is that NumPy supports both writeable and readonly arrays, JAX is immutable (and thus readonly), and PyTorch is always writeable:

```python
import jax
import numpy as np
import torch
```

JAX to NumPy works correctly:

```python
jax_tensor = jax.numpy.array([1., 2., 3.], dtype="float32")
numpy_array = np.asarray(jax_tensor)
assert not numpy_array.flags.writeable  # JAX exports as readonly.
```

NumPy to JAX ignores that the NumPy array is writeable when importing (the same holds for memoryview):

```python
numpy_arr = np.arange(3).astype(np.float32)
jax_tensor = jax.numpy.asarray(numpy_arr)  # JAX imports the mutable array without copying
numpy_arr[0] = 5.
print(repr(jax_tensor))
# DeviceArray([5., 1., 2.], dtype=float32)
```

PyTorch also breaks the first case (JAX's readonly export) when talking to JAX (although you have to go via an ndarray, I guess):

```python
jax_tensor = jax.numpy.array([1., 2., 3.], dtype="float32")
torch_tensor = torch.Tensor(np.asarray(jax_tensor))
# UserWarning: The given NumPy array is not writeable, ... (scary warning!)
torch_tensor[0] = 5  # modifies the "immutable" jax tensor
print(repr(jax_tensor))
# DeviceArray([5., 2., 3.], dtype=float32)
```

And NumPy arrays can be backed by truly read-only memory:

```python
arr = np.arange(20).astype(np.float32)  # torch default dtype
np.save("/tmp/test-array.npy", arr)
# The following is memory mapped read-only:
arr = np.load("/tmp/test-array.npy", mmap_mode="r")
torch_tensor = torch.Tensor(arr)  # (scary warning here)
torch_tensor[0] = 5.
# segmentation fault
```

This is a world where "readonly" information exists, but JAX and PyTorch don't support it, and we leave it up to the user to know about these limitations. Because of this fairly severe limitation, PyTorch chooses to give that scary (and ugly) warning.
JAX tries to protect the user during export (although not in DLPack), but silently treats it as the user's problem during import.
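As things stand, the burden falls on the user. A minimal sketch of the defensive check a user might write before handing a NumPy array to an always-writeable consumer (such as PyTorch); the helper `writeable_or_copy` is hypothetical, and the consumer itself is left out so the sketch needs only NumPy.

```python
import numpy as np


def writeable_or_copy(arr):
    """Hypothetical helper: return arr if writes are allowed, else a private copy."""
    if arr.flags.writeable:
        return arr  # zero-copy: the consumer may legitimately mutate it
    return arr.copy()  # readonly (e.g. read-only mapped) memory: never share it


readonly = np.frombuffer(bytes(12), dtype=np.float32)  # truly readonly buffer
safe = writeable_or_copy(readonly)
safe[0] = 5.0  # the write lands in the copy, not in the readonly buffer
print(readonly[0], safe[0])  # 0.0 5.0
```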

I do realize that within the "array API" you are not supposed to write to an array. So within the "array API" world everything is read-only, and that would solve the problem. But that feels like a narrow view to me: we want DLPack to be widely adopted, and most Array API users are also NumPy, PyTorch, JAX, etc. users (or interact with them). So they should expect to interact with NumPy, where mutability is common. Also, __dlpack__ could well be useful far more generally than the Array API itself.

Both JAX (immutable) and PyTorch (always writeable) have limitations that their users must currently be aware of when exchanging data. But it feels strange to me to force these limitations on NumPy. In particular, I do not like them in np.asarray(dlpack_object). To get around the limitations in np.asarray we would have to:

  • Import all DLPack arrays as readonly, or always copy? This could be part of the standard, and it would make the second point unnecessary. But it prevents any in-place algorithm from accepting DLPack directly.
  • Possibly export only writeable arrays (to play it safe with PyTorch). That seems fine to me, at least for now (a bit weird if combined with the first point, and it doesn't round-trip).
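The import options in the first bullet can be sketched as consumer-side policies. This assumes NumPy >= 1.22 (for `np.from_dlpack`) and uses a writeable NumPy array as the stand-in producer; `import_dlpack` and its `policy` argument are hypothetical, not a NumPy API.

```python
import numpy as np


def import_dlpack(obj, policy="readonly"):
    """Hypothetical consumer-side policies for importing a DLPack producer."""
    arr = np.from_dlpack(obj)  # zero-copy import of the producer's memory
    if policy == "readonly":
        view = arr.view()
        view.flags.writeable = False  # share memory but refuse writes through it
        return view
    if policy == "copy":
        return arr.copy()  # private, writeable memory; nothing is shared
    raise ValueError(f"unknown policy: {policy!r}")


src = np.arange(3.0)
ro = import_dlpack(src)          # shares memory, cannot be written
cp = import_dlpack(src, "copy")  # writeable, but no longer zero-copy
cp[0] = 5.0                      # does not touch src
```

The readonly policy keeps zero-copy but blocks in-place writes; the copy policy allows writes but gives up sharing, which is the trade-off the first bullet describes.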

Clarifying something like that is probably sufficient, at least for now.


But IMO, the proper fix is to add a read-only bit to DLPack (or the __dlpack__ protocol). Adding that without awkwardness requires either extending the API (e.g. a new struct to query metadata) or breaking the ABI. I don't know what the solution is, but whatever it is, I would expect DLPack to be future-proofed anyway.

Once DLPack is future-proofed, the decision of adding a flag could also be deferred to a future version…
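Purely as an illustration of what "future-proofed" could mean: a header that starts with a version and carries a flags word lets a consumer refuse layouts it does not know, and lets a readonly bit be added without another ABI break. The struct names, field layout, `SUPPORTED_MAJOR`, and `FLAG_READ_ONLY` value below are hypothetical, written with `ctypes` only so the sketch runs in Python; this is not the actual DLPack definition.

```python
import ctypes

# Hypothetical flag bit; the name and value are assumptions for illustration.
FLAG_READ_ONLY = 1 << 0
SUPPORTED_MAJOR = 1  # hypothetical: newest major version this consumer knows


class Version(ctypes.Structure):
    _fields_ = [("major", ctypes.c_uint32), ("minor", ctypes.c_uint32)]


class TensorHeader(ctypes.Structure):
    # The version comes first so a consumer can reject unknown layouts;
    # the flags word leaves room for later additions such as "readonly".
    _fields_ = [("version", Version), ("flags", ctypes.c_uint64)]


def import_header(header):
    """Return True if the imported memory must be treated as readonly."""
    if header.version.major > SUPPORTED_MAJOR:
        raise BufferError(f"unsupported major version {header.version.major}")
    return bool(header.flags & FLAG_READ_ONLY)


print(import_header(TensorHeader(Version(1, 0), FLAG_READ_ONLY)))  # True
print(import_header(TensorHeader(Version(1, 0), 0)))               # False
```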

As is, I am not sure NumPy should aim to prefer __dlpack__ in the future (whether or not that is likely to happen). Rather, it feels like NumPy should support DLPack mainly for the sake of those who choose to use it. (Unlike the buffer protocol, which people use without even knowing it, e.g. when writing Cython typed memoryviews.)

Neither JAX nor PyTorch currently supports "readonly" properly (and maybe never will). But I do not think that limitation is an argument against supporting it properly in __dlpack__. NumPy, Dask, CuPy, and Cython (typed memoryviews) do support it properly, after all. It seems almost like turning a PyTorch "user problem" into an ecosystem-wide "user problem"?

Of course, NumPy can continue talking the buffer protocol with Cython and many others (and likely will in any case). And I can live with the limitations, at least in an np.from_dlpack. But I don't like them in np.asarray(), and they just seem like unnecessary issues to me. (In that case, I may still prefer not to export readonly arrays.)
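For comparison, a minimal sketch (NumPy and the standard library only) of how the buffer protocol already carries the readonly bit in both directions, so NumPy and a buffer consumer such as `memoryview` agree without the user having to know:

```python
import numpy as np

arr = np.arange(3.0)
arr.flags.writeable = False

view = memoryview(arr)       # export through the buffer protocol
print(view.readonly)         # True: the readonly bit is exported

back = np.asarray(view)      # import through the buffer protocol
print(back.flags.writeable)  # False: the bit survives the round trip

try:
    view[0] = 5.0            # writes through the exported buffer are rejected
except TypeError as exc:
    print("rejected:", exc)
```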


Or am I the only person who thinks this is an important problem that we have to solve for the user, rather than expect the user to be aware of the limitations?


EDIT: xref the discussion about this at CuPy, which mentioned that nobody supports the readonly flag that __cuda_array_interface__ actually includes, and asked for support in CuPy. (I am not sure why the C level matters, unless it is a high-level C-API; isn't the Python API what matters most? NumPy has a PyArray_FailUnlessWriteable() function for this in its public API.)
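`__cuda_array_interface__` encodes this as a `(pointer, readonly)` tuple under its `data` key, and NumPy's CPU-side `__array_interface__` uses the same convention. A minimal sketch of reading that bit from Python; it uses `__array_interface__` so it runs without a GPU, and the helper `producer_says_readonly` is hypothetical.

```python
import numpy as np


def producer_says_readonly(obj):
    """Hypothetical helper: read the readonly bit from an array interface dict."""
    _pointer, readonly = obj.__array_interface__["data"]
    return readonly


writeable = np.zeros(3)
readonly = np.zeros(3)
readonly.flags.writeable = False

print(producer_says_readonly(writeable))  # False
print(producer_says_readonly(readonly))   # True

# At the Python level, NumPy enforces the flag on assignment with a ValueError.
try:
    readonly[0] = 1.0
except ValueError as exc:
    print("write refused:", exc)
```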
