
BUG: JAX Scan fails for outputs with dims > 1 #287

Closed
@jessegrabowski

Description


Describe the issue:

The JAX scan dispatcher implicitly assumes that the stacked output trace is 1d (i.e., that each step returns a scalar), so an error is raised whenever the per-step outputs have one or more dimensions. The offending code seems to be here:

full_trace = jnp.concatenate(
    [jnp.atleast_1d(init_state), jnp.atleast_1d(trace)],
    axis=0,
)

If trace is nd, init_state needs a leading batch dimension added here so the two arrays can be concatenated.
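
A minimal sketch of what a fix could look like, using placeholder arrays to stand in for the dispatcher's init_state and trace (this mirrors the SIT-SOT case from the repro below; the ndim check is an assumption about how initial buffers arrive, not a tested patch):

import jax.numpy as jnp

# Stand-ins for the dispatcher's values: a 1d initial state and a 2d trace
# with a leading time axis, mirroring the failing (3,) vs. (100, 3) case.
init_state = jnp.zeros(3)
trace = jnp.zeros((100, 3))

# Add the leading time axis only when init_state is missing it, instead of
# relying on jnp.atleast_1d (which only helps when the state is a scalar).
if init_state.ndim == trace.ndim - 1:
    init_state = init_state[None]

full_trace = jnp.concatenate([init_state, trace], axis=0)
print(full_trace.shape)  # (101, 3)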

Reproducible code example:

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.compile.mode import get_mode

x0 = pt.dvector('x0')
A = pt.dmatrix('A')

output, _ = pytensor.scan(
    lambda X, A: A @ X,
    non_sequences=[A],
    outputs_info=[x0],
    n_steps=100,
    mode=get_mode('JAX'),
)
f = pytensor.function([A, x0], [output], mode=get_mode('JAX'))
f(np.random.normal(size=(3, 3)), np.random.normal(size=3))
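
For comparison, here is the same recurrence written directly with jax.lax.scan; it runs fine once the initial state is given a matching leading axis before concatenation (a standalone sketch to illustrate the expected behavior, not the dispatcher code):

import jax
import jax.numpy as jnp
import numpy as np

A_val = jnp.asarray(np.random.normal(size=(3, 3)))
x0_val = jnp.asarray(np.random.normal(size=3))

def step(x, _):
    x_next = A_val @ x
    return x_next, x_next  # (carry, per-step output)

# trace stacks the per-step outputs along a new leading axis: shape (100, 3)
_, trace = jax.lax.scan(step, x0_val, None, length=100)

# Prepending the initial state requires a leading axis on x0_val, exactly
# the axis that jnp.atleast_1d fails to add in the dispatcher.
full_trace = jnp.concatenate([x0_val[None], trace], axis=0)  # shape (101, 3)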

Error message:

TypeError                                 Traceback (most recent call last)
File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pytensor/link/utils.py:202, in streamline.<locals>.streamline_default_f()
    199 for thunk, node, old_storage in zip(
    200     thunks, order, post_thunk_old_storage
    201 ):
--> 202     thunk()
    203     for old_s in old_storage:

File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pytensor/link/basic.py:669, in JITLinker.create_jitable_thunk.<locals>.thunk(fgraph, fgraph_jit, thunk_inputs, thunk_outputs)
    663 def thunk(
    664     fgraph=self.fgraph,
    665     fgraph_jit=fgraph_jit,
    666     thunk_inputs=thunk_inputs,
    667     thunk_outputs=thunk_outputs,
    668 ):
--> 669     outputs = fgraph_jit(*[x[0] for x in thunk_inputs])
    671     for o_var, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs):

    [... skipping hidden 12 frame]

File /tmp/tmp5y3rrz86:13, in jax_funcified_fgraph(A, x0)
     12 # forall_inplace,cpu,scan_fn}(TensorConstant{100}, IncSubtensor{InplaceSet;:int64:}.0, A)
---> 13 tensor_variable_5 = scan(tensor_constant_1, tensor_variable_4, A)
     14 return (tensor_variable_5,)

File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pytensor/link/jax/dispatch/scan.py:187, in jax_funcify_Scan.<locals>.scan(*outer_inputs)
    185     return list(shared_outs)
--> 187 scan_outs_final = get_partial_traces(traces) + get_shared_outs(final_carry)
    189 if len(scan_outs_final) == 1:

File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pytensor/link/jax/dispatch/scan.py:157, in jax_funcify_Scan.<locals>.scan.<locals>.get_partial_traces(traces)
    155 if init_state is not None:
    156     # MIT-SOT and SIT-SOT: The final output should be as long as the input buffer
--> 157     full_trace = jnp.concatenate(
    158         [jnp.atleast_1d(init_state), jnp.atleast_1d(trace)],
    159         axis=0,
    160     )
    161     buffer_size = buffer.shape[0]

File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:1774, in concatenate(arrays, axis, dtype)
   1773 while len(arrays_out) > 1:
-> 1774   arrays_out = [lax.concatenate(arrays_out[i:i+k], axis)
   1775                 for i in range(0, len(arrays_out), k)]
   1776 return arrays_out[0]

File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:1774, in <listcomp>(.0)
   1773 while len(arrays_out) > 1:
-> 1774   arrays_out = [lax.concatenate(arrays_out[i:i+k], axis)
   1775                 for i in range(0, len(arrays_out), k)]
   1776 return arrays_out[0]

    [... skipping hidden 7 frame]

File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/jax/_src/lax/lax.py:2908, in _concatenate_shape_rule(*operands, **kwargs)
   2907   msg = "Cannot concatenate arrays with different numbers of dimensions: got {}."
-> 2908   raise TypeError(msg.format(", ".join(str(o.shape) for o in operands)))
   2909 if not 0 <= dimension < operands[0].ndim:

TypeError: Cannot concatenate arrays with different numbers of dimensions: got (3,), (100, 3).

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
Cell In[18], line 1
----> 1 f(np.random.normal(size=(3,3)), np.random.normal(size=3))

File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pytensor/compile/function/types.py:970, in Function.__call__(self, *args, **kwargs)
    967 t0_fn = time.perf_counter()
    968 try:
    969     outputs = (
--> 970         self.vm()
    971         if output_subset is None
    972         else self.vm(output_subset=output_subset)
    973     )
    974 except Exception:
    975     restore_defaults()

File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pytensor/link/utils.py:206, in streamline.<locals>.streamline_default_f()
    204             old_s[0] = None
    205 except Exception:
--> 206     raise_with_op(fgraph, node, thunk)

File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pytensor/link/utils.py:535, in raise_with_op(fgraph, node, thunk, exc_info, storage_map)
    530     warnings.warn(
    531         f"{exc_type} error does not allow us to add an extra error message"
    532     )
    533     # Some exception need extra parameter in inputs. So forget the
    534     # extra long error message in that case.
--> 535 raise exc_value.with_traceback(exc_trace)

File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pytensor/link/utils.py:202, in streamline.<locals>.streamline_default_f()
    198 try:
    199     for thunk, node, old_storage in zip(
    200         thunks, order, post_thunk_old_storage
    201     ):
--> 202         thunk()
    203         for old_s in old_storage:
    204             old_s[0] = None

File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pytensor/link/basic.py:669, in JITLinker.create_jitable_thunk.<locals>.thunk(fgraph, fgraph_jit, thunk_inputs, thunk_outputs)
    663 def thunk(
    664     fgraph=self.fgraph,
    665     fgraph_jit=fgraph_jit,
    666     thunk_inputs=thunk_inputs,
    667     thunk_outputs=thunk_outputs,
    668 ):
--> 669     outputs = fgraph_jit(*[x[0] for x in thunk_inputs])
    671     for o_var, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs):
    672         compute_map[o_var][0] = True

    [... skipping hidden 12 frame]

File /tmp/tmp5y3rrz86:13, in jax_funcified_fgraph(A, x0)
     11 tensor_variable_4 = incsubtensor(tensor_variable_3, tensor_variable_1, scalar_constant)
     12 # forall_inplace,cpu,scan_fn}(TensorConstant{100}, IncSubtensor{InplaceSet;:int64:}.0, A)
---> 13 tensor_variable_5 = scan(tensor_constant_1, tensor_variable_4, A)
     14 return (tensor_variable_5,)

File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pytensor/link/jax/dispatch/scan.py:187, in jax_funcify_Scan.<locals>.scan(*outer_inputs)
    184     shared_outs = inner_out_shared[: info.n_shared_outs]
    185     return list(shared_outs)
--> 187 scan_outs_final = get_partial_traces(traces) + get_shared_outs(final_carry)
    189 if len(scan_outs_final) == 1:
    190     scan_outs_final = scan_outs_final[0]

File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pytensor/link/jax/dispatch/scan.py:157, in jax_funcify_Scan.<locals>.scan.<locals>.get_partial_traces(traces)
    154 for init_state, trace, buffer in zip(init_states, traces, buffers):
    155     if init_state is not None:
    156         # MIT-SOT and SIT-SOT: The final output should be as long as the input buffer
--> 157         full_trace = jnp.concatenate(
    158             [jnp.atleast_1d(init_state), jnp.atleast_1d(trace)],
    159             axis=0,
    160         )
    161         buffer_size = buffer.shape[0]
    162     else:
    163         # NIT-SOT: Buffer is just the number of entries that should be returned

File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:1774, in concatenate(arrays, axis, dtype)
   1772 k = 16
   1773 while len(arrays_out) > 1:
-> 1774   arrays_out = [lax.concatenate(arrays_out[i:i+k], axis)
   1775                 for i in range(0, len(arrays_out), k)]
   1776 return arrays_out[0]

File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:1774, in <listcomp>(.0)
   1772 k = 16
   1773 while len(arrays_out) > 1:
-> 1774   arrays_out = [lax.concatenate(arrays_out[i:i+k], axis)
   1775                 for i in range(0, len(arrays_out), k)]
   1776 return arrays_out[0]

    [... skipping hidden 7 frame]

File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/jax/_src/lax/lax.py:2908, in _concatenate_shape_rule(*operands, **kwargs)
   2906 if len({operand.ndim for operand in operands}) != 1:
   2907   msg = "Cannot concatenate arrays with different numbers of dimensions: got {}."
-> 2908   raise TypeError(msg.format(", ".join(str(o.shape) for o in operands)))
   2909 if not 0 <= dimension < operands[0].ndim:
   2910   msg = "concatenate dimension out of bounds: dimension {} for shapes {}."

TypeError: Cannot concatenate arrays with different numbers of dimensions: got (3,), (100, 3).
Apply node that caused the error: forall_inplace,cpu,scan_fn}(TensorConstant{100}, IncSubtensor{InplaceSet;:int64:}.0, A)
Toposort index: 5
Inputs types: [TensorType(int8, ()), TensorType(float64, (100, ?)), TensorType(float64, (?, ?))]
Inputs shapes: [(3, 3), (3,)]
Inputs strides: [(24, 8), (8,)]
Inputs values: ['not shown', array([ 1.17146832, -1.98572717, -0.50107902])]
Outputs clients: [['output']]

HINT: Re-running with most PyTensor optimizations disabled could provide a back-trace showing when this node was created. This can be done by setting the PyTensor flag 'optimizer=fast_compile'. If that does not work, PyTensor optimizations can be disabled with 'optimizer=None'.
HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.

PyTensor version information:

2.11.2

Context for the issue:

No response
