Closed
Description
Describe the issue:
The JAX scan dispatcher implicitly assumes that outputs are 1-D, so an error is raised for outputs with more than one dimension (n-D outputs). The offending code seems to be here:
# Quoted from get_partial_traces in pytensor/link/jax/dispatch/scan.py
# (PyTensor 2.11.2), in the branch that handles MIT-SOT / SIT-SOT outputs.
# In the reproduction below, init_state has shape (3,) and trace has shape
# (100, 3). jnp.atleast_1d leaves both arrays unchanged because they are
# already at least 1-D, so concatenating along axis 0 fails with
# "Cannot concatenate arrays with different numbers of dimensions".
# NOTE(review): atleast_1d presumably targets scalar per-step states; an
# n-D state would need a leading axis, e.g. shape (1, 3) here.
full_trace = jnp.concatenate(
[jnp.atleast_1d(init_state), jnp.atleast_1d(trace)],
axis=0,
)
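If `trace` is n-D, `init_state` needs a leading (batch) dimension added here so that the two arrays can be concatenated along axis 0. In the example below, `init_state` has shape `(3,)` and `trace` has shape `(100, 3)`. `jnp.atleast_1d` leaves `init_state` unchanged because it is already 1-D, so it would have to be expanded to shape `(1, 3)` first.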
Reproducible code example:
# Minimal reproduction: a scan whose recurrent state (x0) is a vector rather
# than a scalar, compiled with the JAX backend. On PyTensor 2.11.2, calling
# ``f`` raises
#   TypeError: Cannot concatenate arrays with different numbers of dimensions
import numpy as np  # needed for the random test inputs at the bottom
import pytensor
import pytensor.tensor as pt
from pytensor.compile.mode import get_mode

x0 = pt.dvector('x0')
A = pt.dmatrix('A')

# Each step computes A @ X, so every step produces a vector shaped like x0,
# and the stacked trace is 2-D: (n_steps, len(x0)).
output, _ = pytensor.scan(
    lambda X, A: A @ X,
    non_sequences=[A],
    outputs_info=[x0],
    n_steps=100,
    mode=get_mode('JAX'),
)
f = pytensor.function([A, x0], [output], mode=get_mode('JAX'))

f(np.random.normal(size=(3, 3)), np.random.normal(size=3))
Error message:
TypeError Traceback (most recent call last)
File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pytensor/link/utils.py:202, in streamline.<locals>.streamline_default_f()
199 for thunk, node, old_storage in zip(
200 thunks, order, post_thunk_old_storage
201 ):
--> 202 thunk()
203 for old_s in old_storage:
File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pytensor/link/basic.py:669, in JITLinker.create_jitable_thunk.<locals>.thunk(fgraph, fgraph_jit, thunk_inputs, thunk_outputs)
663 def thunk(
664 fgraph=self.fgraph,
665 fgraph_jit=fgraph_jit,
666 thunk_inputs=thunk_inputs,
667 thunk_outputs=thunk_outputs,
668 ):
--> 669 outputs = fgraph_jit(*[x[0] for x in thunk_inputs])
671 for o_var, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs):
[... skipping hidden 12 frame]
File /tmp/tmp5y3rrz86:13, in jax_funcified_fgraph(A, x0)
12 # forall_inplace,cpu,scan_fn}(TensorConstant{100}, IncSubtensor{InplaceSet;:int64:}.0, A)
---> 13 tensor_variable_5 = scan(tensor_constant_1, tensor_variable_4, A)
14 return (tensor_variable_5,)
File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pytensor/link/jax/dispatch/scan.py:187, in jax_funcify_Scan.<locals>.scan(*outer_inputs)
185 return list(shared_outs)
--> 187 scan_outs_final = get_partial_traces(traces) + get_shared_outs(final_carry)
189 if len(scan_outs_final) == 1:
File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pytensor/link/jax/dispatch/scan.py:157, in jax_funcify_Scan.<locals>.scan.<locals>.get_partial_traces(traces)
155 if init_state is not None:
156 # MIT-SOT and SIT-SOT: The final output should be as long as the input buffer
--> 157 full_trace = jnp.concatenate(
158 [jnp.atleast_1d(init_state), jnp.atleast_1d(trace)],
159 axis=0,
160 )
161 buffer_size = buffer.shape[0]
File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:1774, in concatenate(arrays, axis, dtype)
1773 while len(arrays_out) > 1:
-> 1774 arrays_out = [lax.concatenate(arrays_out[i:i+k], axis)
1775 for i in range(0, len(arrays_out), k)]
1776 return arrays_out[0]
File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:1774, in <listcomp>(.0)
1773 while len(arrays_out) > 1:
-> 1774 arrays_out = [lax.concatenate(arrays_out[i:i+k], axis)
1775 for i in range(0, len(arrays_out), k)]
1776 return arrays_out[0]
[... skipping hidden 7 frame]
File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/jax/_src/lax/lax.py:2908, in _concatenate_shape_rule(*operands, **kwargs)
2907 msg = "Cannot concatenate arrays with different numbers of dimensions: got {}."
-> 2908 raise TypeError(msg.format(", ".join(str(o.shape) for o in operands)))
2909 if not 0 <= dimension < operands[0].ndim:
TypeError: Cannot concatenate arrays with different numbers of dimensions: got (3,), (100, 3).
During handling of the above exception, another exception occurred:
TypeError Traceback (most recent call last)
Cell In[18], line 1
----> 1 f(np.random.normal(size=(3,3)), np.random.normal(size=3))
File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pytensor/compile/function/types.py:970, in Function.__call__(self, *args, **kwargs)
967 t0_fn = time.perf_counter()
968 try:
969 outputs = (
--> 970 self.vm()
971 if output_subset is None
972 else self.vm(output_subset=output_subset)
973 )
974 except Exception:
975 restore_defaults()
File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pytensor/link/utils.py:206, in streamline.<locals>.streamline_default_f()
204 old_s[0] = None
205 except Exception:
--> 206 raise_with_op(fgraph, node, thunk)
File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pytensor/link/utils.py:535, in raise_with_op(fgraph, node, thunk, exc_info, storage_map)
530 warnings.warn(
531 f"{exc_type} error does not allow us to add an extra error message"
532 )
533 # Some exception need extra parameter in inputs. So forget the
534 # extra long error message in that case.
--> 535 raise exc_value.with_traceback(exc_trace)
File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pytensor/link/utils.py:202, in streamline.<locals>.streamline_default_f()
198 try:
199 for thunk, node, old_storage in zip(
200 thunks, order, post_thunk_old_storage
201 ):
--> 202 thunk()
203 for old_s in old_storage:
204 old_s[0] = None
File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pytensor/link/basic.py:669, in JITLinker.create_jitable_thunk.<locals>.thunk(fgraph, fgraph_jit, thunk_inputs, thunk_outputs)
663 def thunk(
664 fgraph=self.fgraph,
665 fgraph_jit=fgraph_jit,
666 thunk_inputs=thunk_inputs,
667 thunk_outputs=thunk_outputs,
668 ):
--> 669 outputs = fgraph_jit(*[x[0] for x in thunk_inputs])
671 for o_var, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs):
672 compute_map[o_var][0] = True
[... skipping hidden 12 frame]
File /tmp/tmp5y3rrz86:13, in jax_funcified_fgraph(A, x0)
11 tensor_variable_4 = incsubtensor(tensor_variable_3, tensor_variable_1, scalar_constant)
12 # forall_inplace,cpu,scan_fn}(TensorConstant{100}, IncSubtensor{InplaceSet;:int64:}.0, A)
---> 13 tensor_variable_5 = scan(tensor_constant_1, tensor_variable_4, A)
14 return (tensor_variable_5,)
File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pytensor/link/jax/dispatch/scan.py:187, in jax_funcify_Scan.<locals>.scan(*outer_inputs)
184 shared_outs = inner_out_shared[: info.n_shared_outs]
185 return list(shared_outs)
--> 187 scan_outs_final = get_partial_traces(traces) + get_shared_outs(final_carry)
189 if len(scan_outs_final) == 1:
190 scan_outs_final = scan_outs_final[0]
File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pytensor/link/jax/dispatch/scan.py:157, in jax_funcify_Scan.<locals>.scan.<locals>.get_partial_traces(traces)
154 for init_state, trace, buffer in zip(init_states, traces, buffers):
155 if init_state is not None:
156 # MIT-SOT and SIT-SOT: The final output should be as long as the input buffer
--> 157 full_trace = jnp.concatenate(
158 [jnp.atleast_1d(init_state), jnp.atleast_1d(trace)],
159 axis=0,
160 )
161 buffer_size = buffer.shape[0]
162 else:
163 # NIT-SOT: Buffer is just the number of entries that should be returned
File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:1774, in concatenate(arrays, axis, dtype)
1772 k = 16
1773 while len(arrays_out) > 1:
-> 1774 arrays_out = [lax.concatenate(arrays_out[i:i+k], axis)
1775 for i in range(0, len(arrays_out), k)]
1776 return arrays_out[0]
File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:1774, in <listcomp>(.0)
1772 k = 16
1773 while len(arrays_out) > 1:
-> 1774 arrays_out = [lax.concatenate(arrays_out[i:i+k], axis)
1775 for i in range(0, len(arrays_out), k)]
1776 return arrays_out[0]
[... skipping hidden 7 frame]
File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/jax/_src/lax/lax.py:2908, in _concatenate_shape_rule(*operands, **kwargs)
2906 if len({operand.ndim for operand in operands}) != 1:
2907 msg = "Cannot concatenate arrays with different numbers of dimensions: got {}."
-> 2908 raise TypeError(msg.format(", ".join(str(o.shape) for o in operands)))
2909 if not 0 <= dimension < operands[0].ndim:
2910 msg = "concatenate dimension out of bounds: dimension {} for shapes {}."
TypeError: Cannot concatenate arrays with different numbers of dimensions: got (3,), (100, 3).
Apply node that caused the error: forall_inplace,cpu,scan_fn}(TensorConstant{100}, IncSubtensor{InplaceSet;:int64:}.0, A)
Toposort index: 5
Inputs types: [TensorType(int8, ()), TensorType(float64, (100, ?)), TensorType(float64, (?, ?))]
Inputs shapes: [(3, 3), (3,)]
Inputs strides: [(24, 8), (8,)]
Inputs values: ['not shown', array([ 1.17146832, -1.98572717, -0.50107902])]
Outputs clients: [['output']]
HINT: Re-running with most PyTensor optimizations disabled could provide a back-trace showing when this node was created. This can be done by setting the PyTensor flag 'optimizer=fast_compile'. If that does not work, PyTensor optimizations can be disabled with 'optimizer=None'.
HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
PyTensor version information:
2.11.2
Context for the issue:
No response