Description
Several parts of the Array API standard assume that array objects are mutable. Some array API implementations (notably JAX) do not support mutating array objects. As a result, the array API support currently being developed in SciPy and scikit-learn is entirely unusable with JAX.
Given this, downstream implementations have three choices:
1. Use mutability semantics, excluding libraries like JAX.
2. Avoid mutability semantics to support libraries like JAX.
3. Explicitly special-case arrays of type `jax.numpy.Array`, changing the implementation logic for that case.
(1) is a bad choice because it means JAX will not be supported. (2) is a bad choice because, for libraries like NumPy, it leads to excessive copying of buffers and worse performance. (3) is a bad choice because it hard-codes knowledge of specific implementations into code that is supposed to be implementation-agnostic.
One way the Array API standard could address this is by adding a "mutable arrays" entry (or something similar) to the existing `capabilities` dict. Downstream implementations could then use strategy (3) without special-casing particular implementations.
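A minimal sketch of how downstream code might branch on such a flag, assuming a hypothetical `"mutable arrays"` key that is not part of the standard today. In the standard, the capabilities dict comes from `xp.__array_namespace_info__().capabilities()`; here it is passed in directly, and plain Python lists stand in for arrays so the sketch runs without any array library. The helper names `supports_mutation` and `clip_negatives` are illustrative only. In real array API code, the mutable branch would be something like `x[x < 0] = 0` and the immutable branch an out-of-place call such as `xp.where`.

```python
from typing import Any


def supports_mutation(capabilities: dict[str, Any]) -> bool:
    # "mutable arrays" is a HYPOTHETICAL capability key (the proposal above).
    # A missing key is treated as mutable, matching the standard's current assumption.
    return bool(capabilities.get("mutable arrays", True))


def clip_negatives(x: list[float], capabilities: dict[str, Any]) -> list[float]:
    """Replace negative entries with 0.0 (hypothetical helper; lists stand in for arrays)."""
    if supports_mutation(capabilities):
        # Mutable path: overwrite entries in place, so no new buffer is allocated.
        for i, v in enumerate(x):
            if v < 0:
                x[i] = 0.0
        return x
    # Immutable path: leave the input untouched and build a new array.
    return [0.0 if v < 0 else v for v in x]


# The same function serves both kinds of library without naming either one.
print(clip_negatives([1.0, -2.0, 3.0], {"boolean indexing": True}))  # prints [1.0, 0.0, 3.0]
print(clip_negatives([1.0, -2.0, 3.0], {"mutable arrays": False}))   # prints [1.0, 0.0, 3.0]
```

Under this proposal, a library like NumPy would keep the cheap in-place path, while a library like JAX would report the capability as `False` and get the copying path, with no type checks on specific array classes.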