Skip to content

Commit 9ae07ab

Browse files
Fix JAX Scan for output ndim > 1 (#288)
1 parent cb417fe commit 9ae07ab

File tree

2 files changed

+106
-4
lines changed

2 files changed

+106
-4
lines changed

pytensor/link/jax/dispatch/scan.py

+4-3
Original file line number | Diff line number | Diff line change
@@ -154,10 +154,11 @@ def get_partial_traces(traces):
154154
for init_state, trace, buffer in zip(init_states, traces, buffers):
155155
if init_state is not None:
156156
# MIT-SOT and SIT-SOT: The final output should be as long as the input buffer
157-
full_trace = jnp.concatenate(
158-
[jnp.atleast_1d(init_state), jnp.atleast_1d(trace)],
159-
axis=0,
157+
trace = jnp.atleast_1d(trace)
158+
init_state = jnp.expand_dims(
159+
init_state, range(trace.ndim - init_state.ndim)
160160
)
161+
full_trace = jnp.concatenate([init_state, trace], axis=0)
161162
buffer_size = buffer.shape[0]
162163
else:
163164
# NIT-SOT: Buffer is just the number of entries that should be returned

tests/link/jax/test_scan.py

+102-1
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@
1313
from pytensor.scan.op import Scan
1414
from pytensor.tensor import random
1515
from pytensor.tensor.math import gammaln, log
16-
from pytensor.tensor.type import lscalar, scalar, vector
16+
from pytensor.tensor.type import dmatrix, dvector, lscalar, scalar, vector
1717
from tests.link.jax.test_basic import compare_jax_and_py
1818

1919

@@ -317,3 +317,104 @@ def input_step_fn(y_tm1, y_tm3, a):
317317

318318
test_input_vals = [np.array(10.0).astype(config.floatX)]
319319
compare_jax_and_py(out_fg, test_input_vals)
320+
321+
322+
@pytest.mark.parametrize("x0_func", [dvector, dmatrix])
@pytest.mark.parametrize("A_func", [dmatrix])
def test_nd_scan_sit_sot(x0_func, A_func):
    """Run the single-tap recurrence ``x_t = A @ x_{t-1}`` through the JAX check.

    The initial state ``x0`` is built either as a vector or as a matrix, so the
    per-step output of the scan has one or two dimensions respectively.

    Parameters
    ----------
    x0_func
        Constructor called with a name to create the symbolic initial state.
    A_func
        Constructor called with a name to create the symbolic non-sequence
        operand ``A``. Only ``dmatrix`` is listed: the previous parametrization
        repeated it, which ran every case twice without adding coverage.
    """
    x0 = x0_func("x0")
    A = A_func("A")

    n_steps = 3
    k = 3

    # Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
    xs, _ = scan(
        lambda X, A: A @ X,
        non_sequences=[A],
        outputs_info=[x0],
        n_steps=n_steps,
        mode=get_mode("JAX"),
    )

    # A length-k vector for a 1-d state, otherwise a (k, k) diagonal matrix.
    x0_val = (
        np.arange(k, dtype=config.floatX)
        if x0.ndim == 1
        else np.diag(np.arange(k, dtype=config.floatX))
    )
    A_val = np.eye(k, dtype=config.floatX)

    fg = FunctionGraph([x0, A], [xs])
    test_input_vals = [x0_val, A_val]
    compare_jax_and_py(fg, test_input_vals)
350+
351+
352+
def test_nd_scan_sit_sot_with_seq():
    """Scan ``A @ x_t`` over the rows of a matrix sequence and run the JAX check.

    NOTE(review): despite the "sit_sot" in the name, no ``outputs_info`` is
    passed here, so the step function only receives the current sequence slice
    and the non-sequence ``A`` -- confirm the name reflects the intent.
    """
    n_steps = 3
    k = 3

    # Concrete inputs: rows 0..n_steps-1 of consecutive values, identity for A.
    seq_val = np.arange(n_steps * k, dtype=config.floatX).reshape((n_steps, k))
    mat_val = np.eye(k, dtype=config.floatX)

    seq = at.matrix("x0", shape=(n_steps, k))
    mat = at.matrix("A", shape=(k, k))

    def apply_matrix(row, matrix):
        return matrix @ row

    # Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
    outputs, _updates = scan(
        apply_matrix,
        sequences=[seq],
        non_sequences=[mat],
        n_steps=n_steps,
        mode=get_mode("JAX"),
    )

    graph = FunctionGraph([seq, mat], [outputs])
    compare_jax_and_py(graph, [seq_val, mat_val])
374+
375+
376+
def test_nd_scan_mit_sot():
    """Run a two-tap recurrence ``A @ x_{t-3} + B @ x_{t-1}`` through the JAX check.

    ``x0`` is a (3, 3) matrix passed as the ``initial`` value for taps
    ``[-3, -1]``; the scan runs for 10 steps with ``A`` and ``B`` as
    non-sequences.
    """
    size = 3

    init = at.matrix("x0", shape=(size, size))
    lhs = at.matrix("A", shape=(size, size))
    rhs = at.matrix("B", shape=(size, size))

    def combine_lags(lag3, lag1, mat_a, mat_b):
        return mat_a @ lag3 + mat_b @ lag1

    # Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
    outputs, _updates = scan(
        combine_lags,
        outputs_info=[{"initial": init, "taps": [-3, -1]}],
        non_sequences=[lhs, rhs],
        n_steps=10,
        mode=get_mode("JAX"),
    )

    graph = FunctionGraph([init, lhs, rhs], [outputs])

    init_val = np.arange(size * size, dtype=config.floatX).reshape((size, size))
    identity_a = np.eye(size, dtype=config.floatX)
    identity_b = np.eye(size, dtype=config.floatX)

    compare_jax_and_py(graph, [init_val, identity_a, identity_b])
397+
398+
399+
def test_nd_scan_sit_sot_with_carry():
    """Run a scan with one recurrent and one non-recurrent output through the JAX check.

    The step function returns ``A @ x`` (fed back through ``outputs_info``) and
    ``x.sum()`` (whose ``outputs_info`` entry is ``None``); both scan outputs
    are used as outputs of the function graph.
    """
    size = 3

    state = at.vector("x0", shape=(size,))
    mat = at.matrix("A", shape=(size, size))

    def advance(prev, matrix):
        # First value is the next state, second is an extra per-step output.
        return matrix @ prev, prev.sum()

    # Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
    outputs, _updates = scan(
        advance,
        outputs_info=[state, None],
        non_sequences=[mat],
        n_steps=10,
        mode=get_mode("JAX"),
    )

    # `outputs` is already the collection of scan outputs, so it is passed as-is.
    graph = FunctionGraph([state, mat], outputs)

    state_val = np.arange(size, dtype=config.floatX)
    mat_val = np.eye(size, dtype=config.floatX)

    compare_jax_and_py(graph, [state_val, mat_val])

0 commit comments

Comments
 (0)