RFC: static vs. dynamic shapes and JAX's .at for simulating in-place ops #609

Open
@rgommers

This is a continuation of a discussion that started a few weeks ago in gh-597 (Cc @soraros). It is closely related to gh-84 (boolean indexing) and gh-24 (mutability and copy/views).

I'll copy the content of @soraros's comment here in full:

Start of comment

I also think the problem is more fundamental than that. JAX is essentially a front-end for XLA, and the primitives provided by XLA (for now) require static shapes. So the line that actually goes wrong is

>>> xs[ix_bool]
array([0, 2, 4])

Note that this code does work in JAX, though it is not jittable, because we don't know its output shape. Let's pretend for a moment that x[ix_bool] += 1 is syntactic sugar for x = x + where(ix_bool, 1, 0) (which works in JAX). The same problem appears when we want x[ix_bool] += [1, 3, 5]. Again, we somehow need to know the shape of the rhs, which is equivalent to knowing the shape of xs[ix_bool] as in the last example.
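To make the equivalence above concrete, here is a minimal NumPy sketch. The variable names (`mask`, `in_place`, `functional`) are illustrative and not from the comment; it only covers the scalar right-hand side, not the `[1, 3, 5]` case.

```python
import numpy as np

x = np.arange(6)
mask = x % 2 == 0  # True at positions 0, 2, 4

# The shape of a boolean-indexed result depends on the data in the mask.
selected = x[mask]  # three elements here, because three entries are True

# In-place masked increment, as NumPy spells it (applied to a copy).
in_place = x.copy()
in_place[mask] += 1

# The same update written so that the output shape is known up front:
# the result always has the shape of x, whatever the mask contains.
functional = x + np.where(mask, 1, 0)
```

Both `in_place` and `functional` hold the same values; only the second form avoids an intermediate whose shape depends on the mask.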

So what we are really working around is the static shape requirement (recall the need for a size parameter in nonzero), which is not exclusive to JAX.
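As an illustration of the `size` parameter mentioned above, the following sketch uses `jax.numpy.nonzero` under `jax.jit`. The function name `even_indices` and the choice of `size=3` and `fill_value=-1` are assumptions made for this example; the caller has to pick a size that suits the data.

```python
import jax
import jax.numpy as jnp


@jax.jit
def even_indices(x):
    # size fixes the output shape at trace time; if fewer matches exist,
    # the remaining slots are padded with fill_value.
    (idx,) = jnp.nonzero(x % 2 == 0, size=3, fill_value=-1)
    return idx


idx = even_indices(jnp.arange(6))                      # three even entries
padded = even_indices(jnp.array([1, 2, 3, 5, 7, 9]))   # one even entry, two pads
```

The price of jittability is that the caller must know (or bound) the number of matches in advance and deal with the padding afterwards.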

Now, for something a bit off-topic:

I think the JAX-style functional syntax a = a.at[...].set(...) for in-place operations looks (and arguably works) better than numpy's, and I'd really like to have it for the array API. Some pros:

  • Looks familiar, and simulates the feel of in-place operation just fine.
  • Makes it clear that nothing is modified. This restricted access pattern would work with any accelerator-backed system. I think it would aid static analysis in systems like Numba as well.
  • More concise, can be chained, and sometimes expresses our intention better.
a = zeros(m)       # initialising a
a[I] += arange(n)  # semantically, still initialising a

# VS

# being concise here is not the important point
# this line becomes a "semantic block" for initialisation
a = zeros(m).at[I].add(arange(n))  # initialising a
  • Can specify the indexing mode (more) easily.
# I think these are fairly cumbersome to represent in `numpy`, as we don't have kwargs for __getitem__
b = a.at[I].add(val, unique_indices=True)     # important info for accelerators
c = b.at[J].get(mode='fill', fill_value=nan)  # sure, we have `take`, but this is uniform and cool
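For readers who have not used `.at`, here is a small runnable sketch of the calls shown in the snippets above, using `jax.numpy`. The array sizes, index values, and variable names are made up for illustration.

```python
import jax.numpy as jnp

a = jnp.zeros(5)

# Functional scatter-add: returns a new array; duplicate indices accumulate.
b = a.at[jnp.array([0, 2, 2])].add(jnp.array([1.0, 2.0, 3.0]))

# unique_indices=True is a promise to the compiler that no index repeats.
d = a.at[jnp.array([1, 3])].add(10.0, unique_indices=True)

# mode="fill" returns fill_value for out-of-bounds reads (index 10 here).
c = b.at[jnp.array([0, 2, 10])].get(mode="fill", fill_value=jnp.nan)
```

Note that `a` still holds zeros afterwards: each call produces a new array rather than mutating its input.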

Some of my thoughts

  • The last two lines of code, which annotate getitem/setitem-like operations with info for accelerators, make an argument that hasn't been made before. If that's something we'd want to support, then this is a way to do it. A context manager would be another way, or the way Numba does it (e.g., a boundscheck keyword to @njit).
  • As discussed in Copy-view behaviour and mutating arrays #24, the syntax for x = x.at[... and numpy et al.'s in-place support are completely equivalent when you have a JIT, and numpy's version is more efficient if you don't, as long as you can guarantee that you are not modifying a view. NumPy's syntax is also arguably nicer: more concise and more familiar. So, from that perspective, .at isn't ideal.
  • It seems like we do need better static shape support though. The dynamic shape support is marked as optional in the standard, so what's the alternative?

The last point is important. Writing generic code is difficult now when you need to, e.g., update values with a mask. Doing that only the JAX way seems like a nonstarter, because it's way too inefficient for NumPy et al. The question though is whether there's something that would work for JAX, TF and Dask. Dask also struggles to some extent with dynamic shapes, although most of it now works (xref dask/dask#2000 and dask/dask#7393). @jakirkham any thoughts on whether you need anything more (possibly JAX-like) for dynamic shape support in Dask?
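To illustrate what a library-agnostic masked update can look like today, here is a sketch built only on `where`. The helper `masked_add` and its `xp` namespace argument are hypothetical and not part of any standard; the example runs it with NumPy, and the same call is expected to work with other namespaces that provide a compatible `where`.

```python
import numpy as np


def masked_add(xp, x, mask, value):
    """Hypothetical helper: add value where mask is True, with a static output shape."""
    # Allocates a new array instead of mutating x, so it does not address
    # the efficiency concern for libraries that support true in-place updates.
    return xp.where(mask, x + value, x)


x = np.asarray([1.0, -2.0, 3.0, -4.0])
y = masked_add(np, x, x < 0, 10.0)
```

This keeps the shape static and avoids boolean indexing, but it computes `x + value` for every element, which is exactly the overhead the paragraph above worries about.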

Labels: Needs Discussion, RFC, topic: Dynamic Shapes, topic: Indexing