13
13
from pytensor .scan .op import Scan
14
14
from pytensor .tensor import random
15
15
from pytensor .tensor .math import gammaln , log
16
- from pytensor .tensor .type import lscalar , scalar , vector , dmatrix , dvector
16
+ from pytensor .tensor .type import dmatrix , dvector , lscalar , scalar , vector
17
17
from tests .link .jax .test_basic import compare_jax_and_py
18
18
19
19
20
- # jax = pytest.importorskip("jax")
21
-
22
- import jax
23
- jax .config .update ('jax_platform_name' , 'cpu' )
20
+ jax = pytest .importorskip ("jax" )
24
21
25
22
26
23
@pytest .mark .parametrize ("view" , [None , (- 1 ,), slice (- 2 , None , None )])
@@ -322,22 +319,24 @@ def input_step_fn(y_tm1, y_tm3, a):
322
319
compare_jax_and_py (out_fg , test_input_vals )
323
320
324
321
325
- @pytest .mark .parametrize (' x0_func' , [dvector , dmatrix ])
326
- @pytest .mark .parametrize (' A_func' , [dmatrix , dmatrix ])
322
+ @pytest .mark .parametrize (" x0_func" , [dvector , dmatrix ])
323
+ @pytest .mark .parametrize (" A_func" , [dmatrix , dmatrix ])
327
324
def test_nd_scan_sit_sot (x0_func , A_func ):
328
- x0 = x0_func ('x0' )
329
- A = A_func ('A' )
325
+ x0 = x0_func ("x0" )
326
+ A = A_func ("A" )
330
327
331
328
n_steps = 3
332
329
k = 3
333
330
334
331
# Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
335
- xs , _ = scan (lambda X , A : A @ X ,
336
- non_sequences = [A ],
337
- outputs_info = [x0 ],
338
- n_steps = n_steps ,
339
- mode = get_mode ('JAX' ))
340
-
332
+ xs , _ = scan (
333
+ lambda X , A : A @ X ,
334
+ non_sequences = [A ],
335
+ outputs_info = [x0 ],
336
+ n_steps = n_steps ,
337
+ mode = get_mode ("JAX" ),
338
+ )
339
+
341
340
x0_val = np .arange (k ) if x0 .ndim == 1 else np .diag (np .arange (k ))
342
341
A_val = np .eye (k )
343
342
@@ -347,19 +346,21 @@ def test_nd_scan_sit_sot(x0_func, A_func):
347
346
348
347
349
348
def test_nd_scan_sit_sot_with_seq ():
350
- x = dmatrix ('x0' )
351
- A = dmatrix ('A' )
349
+ x = dmatrix ("x0" )
350
+ A = dmatrix ("A" )
352
351
353
352
n_steps = 3
354
353
k = 3
355
354
356
355
# Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
357
- xs , _ = scan (lambda X , A : A @ X ,
358
- non_sequences = [A ],
359
- sequences = [x ],
360
- n_steps = n_steps ,
361
- mode = get_mode ('JAX' ))
362
-
356
+ xs , _ = scan (
357
+ lambda X , A : A @ X ,
358
+ non_sequences = [A ],
359
+ sequences = [x ],
360
+ n_steps = n_steps ,
361
+ mode = get_mode ("JAX" ),
362
+ )
363
+
363
364
x_val = np .tile (np .arange (k ), n_steps ).reshape (n_steps , k )
364
365
A_val = np .eye (k )
365
366
@@ -369,17 +370,17 @@ def test_nd_scan_sit_sot_with_seq():
369
370
370
371
371
372
def test_nd_scan_mit_sot ():
372
- x0 = dmatrix ('x0' )
373
- A = dmatrix ('A' )
374
- B = dmatrix ('B' )
373
+ x0 = dmatrix ("x0" )
374
+ A = dmatrix ("A" )
375
+ B = dmatrix ("B" )
375
376
376
377
# Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
377
378
xs , _ = scan (
378
379
lambda xtm3 , xtm1 , A , B : A @ xtm3 + B @ xtm1 ,
379
380
outputs_info = [{"initial" : x0 , "taps" : [- 3 , - 1 ]}],
380
381
non_sequences = [A , B ],
381
382
n_steps = 10 ,
382
- mode = get_mode (' JAX' )
383
+ mode = get_mode (" JAX" ),
383
384
)
384
385
385
386
fg = FunctionGraph ([x0 , A , B ], [xs ])
@@ -392,8 +393,8 @@ def test_nd_scan_mit_sot():
392
393
393
394
394
395
def test_nd_scan_sit_sot_with_carry ():
395
- x0 = dvector ('x0' )
396
- A = dmatrix ('A' )
396
+ x0 = dvector ("x0" )
397
+ A = dmatrix ("A" )
397
398
398
399
def step (x , A ):
399
400
return A @ x , x .sum ()
@@ -404,7 +405,7 @@ def step(x, A):
404
405
outputs_info = [x0 , None ],
405
406
non_sequences = [A ],
406
407
n_steps = 10 ,
407
- mode = get_mode (' JAX' )
408
+ mode = get_mode (" JAX" ),
408
409
)
409
410
410
411
fg = FunctionGraph ([x0 , A ], xs )
0 commit comments