@@ -337,8 +337,10 @@ def test_nd_scan_sit_sot(x0_func, A_func):
337
337
mode = get_mode ("JAX" ),
338
338
)
339
339
340
- x0_val = np .arange (k ) if x0 .ndim == 1 else np .diag (np .arange (k ))
341
- A_val = np .eye (k )
340
+ x0_val = (
341
+ np .arange (k , dtype = config .floatX ) if x0 .ndim == 1 else np .diag (np .arange (k ))
342
+ )
343
+ A_val = np .eye (k , dtype = config .floatX )
342
344
343
345
fg = FunctionGraph ([x0 , A ], [xs ])
344
346
test_input_vals = [x0_val , A_val ]
@@ -361,8 +363,8 @@ def test_nd_scan_sit_sot_with_seq():
361
363
mode = get_mode ("JAX" ),
362
364
)
363
365
364
- x_val = np .tile (np .arange (k ), n_steps ).reshape (n_steps , k )
365
- A_val = np .eye (k )
366
+ x_val = np .tile (np .arange (k , dtype = config . floatX ), n_steps ).reshape (n_steps , k )
367
+ A_val = np .eye (k , dtype = config . floatX )
366
368
367
369
fg = FunctionGraph ([x , A ], [xs ])
368
370
test_input_vals = [x_val , A_val ]
@@ -384,10 +386,9 @@ def test_nd_scan_mit_sot():
384
386
)
385
387
386
388
fg = FunctionGraph ([x0 , A , B ], [xs ])
387
- x0_val = np .r_ [[np .arange (3 ).tolist ()] * 3 ]
388
- print (x0_val )
389
- A_val = np .eye (3 )
390
- B_val = np .eye (3 )
389
+ x0_val = np .r_ [[np .arange (3 , dtype = config .floatX ).tolist ()] * 3 ]
390
+ A_val = np .eye (3 , dtype = config .floatX )
391
+ B_val = np .eye (3 , dtype = config .floatX )
391
392
392
393
test_input_vals = [x0_val , A_val , B_val ]
393
394
compare_jax_and_py (fg , test_input_vals )
@@ -410,8 +411,8 @@ def step(x, A):
410
411
)
411
412
412
413
fg = FunctionGraph ([x0 , A ], xs )
413
- x0_val = np .arange (3 )
414
- A_val = np .eye (3 )
414
+ x0_val = np .arange (3 , dtype = config . floatX )
415
+ A_val = np .eye (3 , dtype = config . floatX )
415
416
416
417
test_input_vals = [x0_val , A_val ]
417
418
compare_jax_and_py (fg , test_input_vals )
0 commit comments