Description
Describe the issue:
Here is a link to the original discussion on the PyMC Discourse:
https://discourse.pymc.io/t/pytensor-jax-does-not-support-slicing-arrays-with-a-dynamic-slice-length/12163?u=pbaggens
The problem is that a slice with a dynamic start index cannot be taken from a shared variable in JAX mode. When I try, I get the error:
NotImplementedError: JAX does not support slicing arrays with a dynamic slice length.
The first thing I tried was disabling the check on line 44 of pytensor/link/jax/dispatch/subtensor.py.
With that check removed, I get this error instead:
IndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>, None). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).
Apply node that caused the error: DeepCopyOp(Subtensor{int32:int32:}.0)
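The JAX error message itself points at `lax.dynamic_slice`, which handles exactly this pattern: a traced (dynamic) start index combined with static slice sizes. A minimal sketch in plain JAX, independent of PyTensor, showing the primitive a fix could dispatch to:

```python
import jax
import jax.numpy as jnp
from jax import lax

data = jnp.arange(12.0).reshape(4, 3)

@jax.jit
def fetch(start):
    # `start` is a traced value under jit, but the slice sizes (2, 3)
    # are static, so lax.dynamic_slice is allowed here
    return lax.dynamic_slice(data, (start, 0), (2, 3))

print(fetch(1))  # rows 1 and 2 of `data`
```

Note that `lax.dynamic_slice` clamps out-of-bounds start indices so the slice always fits inside the array, which differs slightly from NumPy slicing semantics.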
Reproducible code example:
import collections

import numpy as np

import pytensor
import pytensor.tensor as T

seed = 15
numpy_rng = np.random.RandomState(seed)

# ---------------------- constants ------------------
dim = 100
bs = 25
data_size = (100, dim)
batch_input_shape = (bs, dim)

# create shared input data and a smaller shared batch
data = numpy_rng.uniform(low=-1.0, high=1.0, size=data_size).astype(pytensor.config.floatX)
input_data = pytensor.shared(value=data, name='input_data')
data = numpy_rng.uniform(low=-1.0, high=1.0, size=batch_input_shape).astype(pytensor.config.floatX)
batch_data = pytensor.shared(value=data, name='batch_data')

# create a function that fetches a batch starting at a dynamic address
start_index = T.iscalar()
updates = collections.OrderedDict()
updates[batch_data] = input_data[start_index:start_index + bs]
# compiling with the JAX backend is what triggers the error
fn = pytensor.function(inputs=[start_index], outputs=[], updates=updates, mode="JAX")

for i in range(4):
    fn(i * bs)

x = batch_data.get_value()
print(x)
Error message:
No response
PyTensor version information:
PyTensor version 2.11.1
Context for the issue:
Fetching a mini-batch from a shared variable that already lives on the GPU is standard practice in neural-network training; it is more efficient than transferring each batch to the GPU separately. This should be a high-priority issue.