
BUG: Pytensor-JAX unable to use dynamic slice indexing  #312

Open
@pbaggens1

Description


Describe the issue:

Here is a link to the original discussion at Pymc:
https://discourse.pymc.io/t/pytensor-jax-does-not-support-slicing-arrays-with-a-dynamic-slice-length/12163?u=pbaggens

The problem is that one cannot slice a shared variable with a dynamic start index in JAX mode.
When I do this, I get the error:

NotImplementedError: JAX does not support slicing arrays with a dynamic slice length.

The first thing I tried was to disable the check on line 44 of pytensor/link/jax/dispatch/subtensor.py.

After disabling this test, I get this error instead:

IndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>, None). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).
Apply node that caused the error: DeepCopyOp(Subtensor{int32:int32:}.0)
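The JAX error message itself hints at the fix: `jax.lax.dynamic_slice` accepts a traced (dynamic) start index as long as the slice *sizes* are static. A minimal standalone JAX sketch, independent of PyTensor (the array shape and batch size here are illustrative, not taken from the issue):

```python
import jax
import jax.numpy as jnp
from jax import lax

data = jnp.arange(12.0).reshape(6, 2)  # illustrative stand-in for the shared data
bs = 2  # static slice length

@jax.jit
def fetch(start):
    # Inside jit, `start` is a traced value, so data[start:start + bs]
    # would raise the IndexError above. lax.dynamic_slice only requires
    # the slice sizes (bs, data.shape[1]) to be static; the start
    # indices may be dynamic.
    return lax.dynamic_slice(data, (start, 0), (bs, data.shape[1]))

print(fetch(2))  # rows 2 and 3 of `data`
```

This is presumably what the PyTensor JAX dispatcher would need to emit for `Subtensor` nodes whose slice length is statically known.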

Reproducible code example:

import pytensor
import pytensor.tensor as T
import numpy as np
import collections

seed = 15
numpy_rng = np.random.RandomState(seed)

# ---------------------- constants ------------------
dim = 100
bs = 25
data_size = (100, dim)
batch_input_shape = (bs, dim)

# create shared input data and a smaller batch buffer
data = numpy_rng.uniform(low=-1.0, high=1.0, size=data_size).astype(pytensor.config.floatX)
input_data = pytensor.shared(value=data, name='input_data')
data = numpy_rng.uniform(low=-1.0, high=1.0, size=batch_input_shape).astype(pytensor.config.floatX)
batch_data = pytensor.shared(value=data, name='batch_data')

# create a function that fetches a batch starting at a dynamic address
start_index = T.iscalar()
updates = collections.OrderedDict()
updates[batch_data] = input_data[start_index:start_index + bs]
fn = pytensor.function(inputs=[start_index], outputs=[], updates=updates)

for i in range(4):
    fn(i * bs)
    x = batch_data.get_value()
    print(x)
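A possible workaround, assuming the batch size `bs` is a Python constant: index with an integer vector instead of a slice. Integer-array indexing lowers to a gather, which JAX supports even when the indices are traced. In the PyTensor graph this would read `input_data[start_index + T.arange(bs)]` (advanced indexing; not verified against this PyTensor version). The JAX-level equivalent, with illustrative shapes:

```python
import jax
import jax.numpy as jnp

data = jnp.arange(12.0).reshape(6, 2)  # illustrative stand-in for the shared data
bs = 2  # static batch size

@jax.jit
def fetch(start):
    # start + jnp.arange(bs) is an integer index array of static length bs;
    # indexing with it compiles to a gather, which works with a traced start.
    return data[start + jnp.arange(bs)]

print(fetch(2))  # rows 2 and 3 of `data`
```

A gather is generally less efficient than a contiguous dynamic slice, but it avoids the static-slice restriction entirely.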

Error message:

No response

PyTensor version information:

PyTensor version 2.11.1

Context for the issue:

Fetching a batch from a larger dataset already resident in GPU shared memory is standard practice when training neural networks: it is more efficient than uploading each batch to the GPU separately. This should be a high-priority issue.
