Skip to content

Fix JAX Scan for output ndim > 1 #288

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
May 10, 2023
7 changes: 4 additions & 3 deletions pytensor/link/jax/dispatch/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,10 +154,11 @@ def get_partial_traces(traces):
for init_state, trace, buffer in zip(init_states, traces, buffers):
if init_state is not None:
# MIT-SOT and SIT-SOT: The final output should be as long as the input buffer
full_trace = jnp.concatenate(
[jnp.atleast_1d(init_state), jnp.atleast_1d(trace)],
axis=0,
trace = jnp.atleast_1d(trace)
init_state = jnp.expand_dims(
init_state, range(trace.ndim - init_state.ndim)
)
full_trace = jnp.concatenate([init_state, trace], axis=0)
buffer_size = buffer.shape[0]
else:
# NIT-SOT: Buffer is just the number of entries that should be returned
Expand Down
103 changes: 102 additions & 1 deletion tests/link/jax/test_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from pytensor.scan.op import Scan
from pytensor.tensor import random
from pytensor.tensor.math import gammaln, log
from pytensor.tensor.type import lscalar, scalar, vector
from pytensor.tensor.type import dmatrix, dvector, lscalar, scalar, vector
from tests.link.jax.test_basic import compare_jax_and_py


Expand Down Expand Up @@ -317,3 +317,104 @@ def input_step_fn(y_tm1, y_tm3, a):

test_input_vals = [np.array(10.0).astype(config.floatX)]
compare_jax_and_py(out_fg, test_input_vals)


@pytest.mark.parametrize("x0_func", [dvector, dmatrix])
@pytest.mark.parametrize("A_func", [dmatrix])
def test_nd_scan_sit_sot(x0_func, A_func):
    """Compare JAX and Python results for a recurrent scan ``x <- A @ x``.

    The initial state ``x0`` is built as either a vector or a matrix
    (``x0_func``), so the scan output has more than one dimension.
    ``A_func`` builds the non-sequence matrix ``A``; it is parametrized
    with a single constructor because listing ``dmatrix`` twice only
    re-ran identical cases.
    """
    x0 = x0_func("x0")
    A = A_func("A")

    n_steps = 3
    k = 3

    # Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
    xs, _ = scan(
        lambda X, A: A @ X,
        non_sequences=[A],
        outputs_info=[x0],
        n_steps=n_steps,
        mode=get_mode("JAX"),
    )

    # Vector initial state: 0..k-1; matrix initial state: the same values
    # placed on the diagonal of a (k, k) matrix.
    x0_val = (
        np.arange(k, dtype=config.floatX)
        if x0.ndim == 1
        else np.diag(np.arange(k, dtype=config.floatX))
    )
    A_val = np.eye(k, dtype=config.floatX)

    fg = FunctionGraph([x0, A], [xs])
    test_input_vals = [x0_val, A_val]
    compare_jax_and_py(fg, test_input_vals)


def test_nd_scan_sit_sot_with_seq():
    """Compare JAX and Python results for a scan over a matrix sequence.

    Each step receives one entry of the sequence ``x`` and the
    non-sequence matrix ``A`` and returns ``A @ entry``; no
    ``outputs_info`` is passed to ``scan``.
    """
    n_steps = 3
    k = 3

    x = at.matrix("x0", shape=(n_steps, k))
    A = at.matrix("A", shape=(k, k))

    def step(seq_entry, mat):
        return mat @ seq_entry

    # The inner function is compiled in JAX mode so that no GEMM Op
    # ends up in the JAX graph.
    outputs, _updates = scan(
        step,
        sequences=[x],
        non_sequences=[A],
        n_steps=n_steps,
        mode=get_mode("JAX"),
    )
    fg = FunctionGraph([x, A], [outputs])

    x_val = np.arange(n_steps * k, dtype=config.floatX).reshape((n_steps, k))
    A_val = np.eye(k, dtype=config.floatX)

    compare_jax_and_py(fg, [x_val, A_val])


def test_nd_scan_mit_sot():
    """Compare JAX and Python results for a scan with taps ``[-3, -1]``.

    The step computes ``A @ x[t-3] + B @ x[t-1]``, with the initial
    values taken from the matrix ``x0`` and ``A``, ``B`` passed as
    non-sequences.
    """
    x0 = at.matrix("x0", shape=(3, 3))
    A = at.matrix("A", shape=(3, 3))
    B = at.matrix("B", shape=(3, 3))

    def step(x_lag3, x_lag1, mat_a, mat_b):
        return mat_a @ x_lag3 + mat_b @ x_lag1

    initial_spec = {"initial": x0, "taps": [-3, -1]}

    # The inner function is compiled in JAX mode so that no GEMM Op
    # ends up in the JAX graph.
    outputs, _updates = scan(
        step,
        outputs_info=[initial_spec],
        non_sequences=[A, B],
        n_steps=10,
        mode=get_mode("JAX"),
    )
    fg = FunctionGraph([x0, A, B], [outputs])

    x0_val = np.arange(9, dtype=config.floatX).reshape((3, 3))
    identity = np.eye(3, dtype=config.floatX)
    A_val = identity
    B_val = np.eye(3, dtype=config.floatX)

    compare_jax_and_py(fg, [x0_val, A_val, B_val])


def test_nd_scan_sit_sot_with_carry():
    """Compare JAX and Python results for a scan with two outputs.

    The step returns ``A @ x`` (fed back via the initial state ``x0``)
    together with ``x.sum()``, whose ``outputs_info`` entry is ``None``.
    Both scan outputs are placed in the function graph.
    """
    x0 = at.vector("x0", shape=(3,))
    A = at.matrix("A", shape=(3, 3))

    # The inner function is compiled in JAX mode so that no GEMM Op
    # ends up in the JAX graph.
    outputs, _updates = scan(
        lambda state, mat: (mat @ state, state.sum()),
        outputs_info=[x0, None],
        non_sequences=[A],
        n_steps=10,
        mode=get_mode("JAX"),
    )
    fg = FunctionGraph([x0, A], outputs)

    x0_val = np.arange(3, dtype=config.floatX)
    A_val = np.eye(3, dtype=config.floatX)

    compare_jax_and_py(fg, [x0_val, A_val])