Skip to content

Commit 3e71b40

Browse files
set dtype of test inputs to
1 parent aebbef2 commit 3e71b40

File tree

1 file changed

+11
-10
lines changed

1 file changed

+11
-10
lines changed

tests/link/jax/test_scan.py

+11-10
Original file line numberDiff line numberDiff line change
@@ -337,8 +337,10 @@ def test_nd_scan_sit_sot(x0_func, A_func):
337337
mode=get_mode("JAX"),
338338
)
339339

340-
x0_val = np.arange(k) if x0.ndim == 1 else np.diag(np.arange(k))
341-
A_val = np.eye(k)
340+
x0_val = (
341+
np.arange(k, dtype=config.floatX) if x0.ndim == 1 else np.diag(np.arange(k))
342+
)
343+
A_val = np.eye(k, dtype=config.floatX)
342344

343345
fg = FunctionGraph([x0, A], [xs])
344346
test_input_vals = [x0_val, A_val]
@@ -361,8 +363,8 @@ def test_nd_scan_sit_sot_with_seq():
361363
mode=get_mode("JAX"),
362364
)
363365

364-
x_val = np.tile(np.arange(k), n_steps).reshape(n_steps, k)
365-
A_val = np.eye(k)
366+
x_val = np.tile(np.arange(k, dtype=config.floatX), n_steps).reshape(n_steps, k)
367+
A_val = np.eye(k, dtype=config.floatX)
366368

367369
fg = FunctionGraph([x, A], [xs])
368370
test_input_vals = [x_val, A_val]
@@ -384,10 +386,9 @@ def test_nd_scan_mit_sot():
384386
)
385387

386388
fg = FunctionGraph([x0, A, B], [xs])
387-
x0_val = np.r_[[np.arange(3).tolist()] * 3]
388-
print(x0_val)
389-
A_val = np.eye(3)
390-
B_val = np.eye(3)
389+
x0_val = np.r_[[np.arange(3, dtype=config.floatX).tolist()] * 3]
390+
A_val = np.eye(3, dtype=config.floatX)
391+
B_val = np.eye(3, dtype=config.floatX)
391392

392393
test_input_vals = [x0_val, A_val, B_val]
393394
compare_jax_and_py(fg, test_input_vals)
@@ -410,8 +411,8 @@ def step(x, A):
410411
)
411412

412413
fg = FunctionGraph([x0, A], xs)
413-
x0_val = np.arange(3)
414-
A_val = np.eye(3)
414+
x0_val = np.arange(3, dtype=config.floatX)
415+
A_val = np.eye(3, dtype=config.floatX)
415416

416417
test_input_vals = [x0_val, A_val]
417418
compare_jax_and_py(fg, test_input_vals)

0 commit comments

Comments
 (0)