Skip to content

Commit c84bd0b

Browse files
run pre-commit hooks
1 parent 6159839 commit c84bd0b

File tree

2 files changed

+44
-38
lines changed

2 files changed

+44
-38
lines changed

pytensor/link/jax/dispatch/scan.py

+13-8
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from pytensor.link.jax.dispatch.basic import jax_funcify
55
from pytensor.scan.op import Scan
6-
from pytensor.compile.sharedvalue import SharedVariable
6+
77

88
@jax_funcify.register(Scan)
99
def jax_funcify_Scan(op: Scan, **kwargs):
@@ -42,11 +42,12 @@ def scan(*outer_inputs):
4242
op.outer_non_seqs(outer_inputs),
4343
) # JAX `init`
4444

45-
ndim_out_core = [getattr(x, 'ndim', 0) for x in op.fgraph.outputs]
46-
ndim_in_core = [getattr(x, 'ndim', 0) for x in mit_sot_init + sit_sot_init]
45+
ndim_out_core = [getattr(x, "ndim", 0) for x in op.fgraph.outputs]
46+
ndim_in_core = [getattr(x, "ndim", 0) for x in mit_sot_init + sit_sot_init]
4747

48-
n_batchdims = [int(dims_out > 0) for dims_out, dims_in in zip(ndim_out_core, ndim_in_core)] +\
49-
[0] * op.info.n_nit_sot
48+
n_batchdims = [
49+
int(dims_out > 0) for dims_out, dims_in in zip(ndim_out_core, ndim_in_core)
50+
] + [0] * op.info.n_nit_sot
5051

5152
def jax_args_to_inner_func_args(carry, x):
5253
"""Convert JAX scan arguments into format expected by scan_inner_func.
@@ -157,13 +158,17 @@ def get_partial_traces(traces):
157158
+ op.outer_nitsot(outer_inputs)
158159
)
159160
partial_traces = []
160-
for init_state, trace, n_batchdim, buffer in zip(init_states, traces, n_batchdims, buffers):
161+
for init_state, trace, n_batchdim, buffer in zip(
162+
init_states, traces, n_batchdims, buffers
163+
):
161164
if init_state is not None:
162165
# MIT-SOT and SIT-SOT: The final output should be as long as the input buffer
163166
batch_idx = range(n_batchdim)
164167
full_trace = jnp.concatenate(
165-
[jnp.atleast_1d(jnp.expand_dims(init_state, batch_idx)),
166-
jnp.atleast_1d(trace)],
168+
[
169+
jnp.atleast_1d(jnp.expand_dims(init_state, batch_idx)),
170+
jnp.atleast_1d(trace),
171+
],
167172
axis=0,
168173
)
169174
buffer_size = buffer.shape[0]

tests/link/jax/test_scan.py

+31-30
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,11 @@
1313
from pytensor.scan.op import Scan
1414
from pytensor.tensor import random
1515
from pytensor.tensor.math import gammaln, log
16-
from pytensor.tensor.type import lscalar, scalar, vector, dmatrix, dvector
16+
from pytensor.tensor.type import dmatrix, dvector, lscalar, scalar, vector
1717
from tests.link.jax.test_basic import compare_jax_and_py
1818

1919

20-
# jax = pytest.importorskip("jax")
21-
22-
import jax
23-
jax.config.update('jax_platform_name', 'cpu')
20+
jax = pytest.importorskip("jax")
2421

2522

2623
@pytest.mark.parametrize("view", [None, (-1,), slice(-2, None, None)])
@@ -322,22 +319,24 @@ def input_step_fn(y_tm1, y_tm3, a):
322319
compare_jax_and_py(out_fg, test_input_vals)
323320

324321

325-
@pytest.mark.parametrize('x0_func', [dvector, dmatrix])
326-
@pytest.mark.parametrize('A_func', [dmatrix, dmatrix])
322+
@pytest.mark.parametrize("x0_func", [dvector, dmatrix])
323+
@pytest.mark.parametrize("A_func", [dmatrix, dmatrix])
327324
def test_nd_scan_sit_sot(x0_func, A_func):
328-
x0 = x0_func('x0')
329-
A = A_func('A')
325+
x0 = x0_func("x0")
326+
A = A_func("A")
330327

331328
n_steps = 3
332329
k = 3
333330

334331
# Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
335-
xs, _ = scan(lambda X, A: A @ X,
336-
non_sequences=[A],
337-
outputs_info=[x0],
338-
n_steps=n_steps,
339-
mode=get_mode('JAX'))
340-
332+
xs, _ = scan(
333+
lambda X, A: A @ X,
334+
non_sequences=[A],
335+
outputs_info=[x0],
336+
n_steps=n_steps,
337+
mode=get_mode("JAX"),
338+
)
339+
341340
x0_val = np.arange(k) if x0.ndim == 1 else np.diag(np.arange(k))
342341
A_val = np.eye(k)
343342

@@ -347,19 +346,21 @@ def test_nd_scan_sit_sot(x0_func, A_func):
347346

348347

349348
def test_nd_scan_sit_sot_with_seq():
350-
x = dmatrix('x0')
351-
A = dmatrix('A')
349+
x = dmatrix("x0")
350+
A = dmatrix("A")
352351

353352
n_steps = 3
354353
k = 3
355354

356355
# Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
357-
xs, _ = scan(lambda X, A: A @ X,
358-
non_sequences=[A],
359-
sequences=[x],
360-
n_steps=n_steps,
361-
mode=get_mode('JAX'))
362-
356+
xs, _ = scan(
357+
lambda X, A: A @ X,
358+
non_sequences=[A],
359+
sequences=[x],
360+
n_steps=n_steps,
361+
mode=get_mode("JAX"),
362+
)
363+
363364
x_val = np.tile(np.arange(k), n_steps).reshape(n_steps, k)
364365
A_val = np.eye(k)
365366

@@ -369,17 +370,17 @@ def test_nd_scan_sit_sot_with_seq():
369370

370371

371372
def test_nd_scan_mit_sot():
372-
x0 = dmatrix('x0')
373-
A = dmatrix('A')
374-
B = dmatrix('B')
373+
x0 = dmatrix("x0")
374+
A = dmatrix("A")
375+
B = dmatrix("B")
375376

376377
# Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
377378
xs, _ = scan(
378379
lambda xtm3, xtm1, A, B: A @ xtm3 + B @ xtm1,
379380
outputs_info=[{"initial": x0, "taps": [-3, -1]}],
380381
non_sequences=[A, B],
381382
n_steps=10,
382-
mode = get_mode('JAX')
383+
mode=get_mode("JAX"),
383384
)
384385

385386
fg = FunctionGraph([x0, A, B], [xs])
@@ -392,8 +393,8 @@ def test_nd_scan_mit_sot():
392393

393394

394395
def test_nd_scan_sit_sot_with_carry():
395-
x0 = dvector('x0')
396-
A = dmatrix('A')
396+
x0 = dvector("x0")
397+
A = dmatrix("A")
397398

398399
def step(x, A):
399400
return A @ x, x.sum()
@@ -404,7 +405,7 @@ def step(x, A):
404405
outputs_info=[x0, None],
405406
non_sequences=[A],
406407
n_steps=10,
407-
mode = get_mode('JAX')
408+
mode=get_mode("JAX"),
408409
)
409410

410411
fg = FunctionGraph([x0, A], xs)

0 commit comments

Comments
 (0)