1
1
import re
2
2
3
+ # jax = pytest.importorskip("jax")
4
+ import jax
3
5
import numpy as np
4
6
import pytest
5
7
17
19
from tests .link .jax .test_basic import compare_jax_and_py
18
20
19
21
20
- jax = pytest .importorskip ("jax" )
22
+ jax .config .update ("jax_platform_name" , "cpu" )
23
+ config .floatX = "float32"
21
24
22
25
23
26
@pytest .mark .parametrize ("view" , [None , (- 1 ,), slice (- 2 , None , None )])
@@ -338,7 +341,9 @@ def test_nd_scan_sit_sot(x0_func, A_func):
338
341
)
339
342
340
343
x0_val = (
341
- np .arange (k , dtype = config .floatX ) if x0 .ndim == 1 else np .diag (np .arange (k ))
344
+ np .arange (k , dtype = config .floatX )
345
+ if x0 .ndim == 1
346
+ else np .diag (np .arange (k , dtype = config .floatX ))
342
347
)
343
348
A_val = np .eye (k , dtype = config .floatX )
344
349
@@ -363,7 +368,7 @@ def test_nd_scan_sit_sot_with_seq():
363
368
mode = get_mode ("JAX" ),
364
369
)
365
370
366
- x_val = np .tile ( np . arange (k , dtype = config .floatX ), n_steps ).reshape (n_steps , k )
371
+ x_val = np .arange (n_steps * k , dtype = config .floatX ).reshape (n_steps , k )
367
372
A_val = np .eye (k , dtype = config .floatX )
368
373
369
374
fg = FunctionGraph ([x , A ], [xs ])
@@ -386,7 +391,7 @@ def test_nd_scan_mit_sot():
386
391
)
387
392
388
393
fg = FunctionGraph ([x0 , A , B ], [xs ])
389
- x0_val = np .r_ [[ np . arange (3 , dtype = config .floatX ).tolist ()] * 3 ]
394
+ x0_val = np .arange (9 , dtype = config .floatX ).reshape ( 3 , 3 )
390
395
A_val = np .eye (3 , dtype = config .floatX )
391
396
B_val = np .eye (3 , dtype = config .floatX )
392
397
0 commit comments