Skip to content

Commit aebbef2

Browse files
re-write ndim check, specify static sizes for tests
1 parent 83fba01 commit aebbef2

File tree

2 files changed

+14
-26
lines changed

2 files changed

+14
-26
lines changed

pytensor/link/jax/dispatch/scan.py

+5-18
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,6 @@ def scan(*outer_inputs):
4242
op.outer_non_seqs(outer_inputs),
4343
) # JAX `init`
4444

45-
ndim_out_core = [getattr(x, "ndim", 0) for x in op.fgraph.outputs]
46-
# ndim_in_core = [getattr(x, "ndim", 0) for x in mit_sot_init + sit_sot_init]
47-
48-
add_batchdim_flags = [dims_out > 0 for dims_out in ndim_out_core] + [
49-
False
50-
] * op.info.n_nit_sot
51-
5245
def jax_args_to_inner_func_args(carry, x):
5346
"""Convert JAX scan arguments into format expected by scan_inner_func.
5447
@@ -158,20 +151,14 @@ def get_partial_traces(traces):
158151
+ op.outer_nitsot(outer_inputs)
159152
)
160153
partial_traces = []
161-
for init_state, trace, add_batchdim, buffer in zip(
162-
init_states, traces, add_batchdim_flags, buffers
163-
):
154+
for init_state, trace, buffer in zip(init_states, traces, buffers):
164155
if init_state is not None:
165156
# MIT-SOT and SIT-SOT: The final output should be as long as the input buffer
166-
if add_batchdim:
167-
init_state = jnp.expand_dims(init_state, 0)
168-
full_trace = jnp.concatenate(
169-
[
170-
jnp.atleast_1d(init_state),
171-
jnp.atleast_1d(trace),
172-
],
173-
axis=0,
157+
trace = jnp.atleast_1d(trace)
158+
init_state = jnp.expand_dims(
159+
init_state, range(trace.ndim - init_state.ndim)
174160
)
161+
full_trace = jnp.concatenate([init_state, trace], axis=0)
175162
buffer_size = buffer.shape[0]
176163
else:
177164
# NIT-SOT: Buffer is just the number of entries that should be returned

tests/link/jax/test_scan.py

+9-8
Original file line numberDiff line numberDiff line change
@@ -346,12 +346,12 @@ def test_nd_scan_sit_sot(x0_func, A_func):
346346

347347

348348
def test_nd_scan_sit_sot_with_seq():
349-
x = dmatrix("x0")
350-
A = dmatrix("A")
351-
352349
n_steps = 3
353350
k = 3
354351

352+
x = at.matrix("x0", shape=(n_steps, k))
353+
A = at.matrix("A", shape=(k, k))
354+
355355
# Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
356356
xs, _ = scan(
357357
lambda X, A: A @ X,
@@ -370,9 +370,9 @@ def test_nd_scan_sit_sot_with_seq():
370370

371371

372372
def test_nd_scan_mit_sot():
373-
x0 = dmatrix("x0")
374-
A = dmatrix("A")
375-
B = dmatrix("B")
373+
x0 = at.matrix("x0", shape=(3, 3))
374+
A = at.matrix("A", shape=(3, 3))
375+
B = at.matrix("B", shape=(3, 3))
376376

377377
# Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
378378
xs, _ = scan(
@@ -385,6 +385,7 @@ def test_nd_scan_mit_sot():
385385

386386
fg = FunctionGraph([x0, A, B], [xs])
387387
x0_val = np.r_[[np.arange(3).tolist()] * 3]
388+
print(x0_val)
388389
A_val = np.eye(3)
389390
B_val = np.eye(3)
390391

@@ -393,8 +394,8 @@ def test_nd_scan_mit_sot():
393394

394395

395396
def test_nd_scan_sit_sot_with_carry():
396-
x0 = dvector("x0")
397-
A = dmatrix("A")
397+
x0 = at.vector("x0", shape=(3,))
398+
A = at.matrix("A", shape=(3, 3))
398399

399400
def step(x, A):
400401
return A @ x, x.sum()

0 commit comments

Comments
 (0)