Skip to content

Commit 62cee00

Browse files
committed
Test that JAX scan can handle simple dynamic sequences lengths
1 parent 82823de commit 62cee00

File tree

1 file changed

+10
-0
lines changed

1 file changed

+10
-0
lines changed

tests/link/jax/test_scan.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,3 +427,13 @@ def test_default_mode_excludes_incompatible_rewrites():
427427
out, _ = scan(lambda a, b: a @ b, outputs_info=[A], non_sequences=[B], n_steps=2)
428428
fg = FunctionGraph([A, B], [out])
429429
compare_jax_and_py(fg, [np.eye(3), np.eye(3)])
430+
431+
432+
def test_dynamic_sequence_length():
433+
x = pt.tensor("x", shape=(None,))
434+
out, _ = scan(lambda x: x + 1, sequences=[x])
435+
436+
f = function([x], out, mode=get_mode("JAX").excluding("scan"))
437+
assert sum(isinstance(node.op, Scan) for node in f.maker.fgraph.apply_nodes) == 1
438+
np.testing.assert_allclose(f([]), [])
439+
np.testing.assert_allclose(f([1, 2, 3]), np.array([2, 3, 4]))

0 commit comments

Comments
 (0)