Skip to content

Commit 35aa643

Browse files
track down missing dtype=config.floatX in test_scan.py
1 parent 3e71b40 commit 35aa643

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

tests/link/jax/test_scan.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import re
22

3+
# jax = pytest.importorskip("jax")
4+
import jax
35
import numpy as np
46
import pytest
57

@@ -17,7 +19,8 @@
1719
from tests.link.jax.test_basic import compare_jax_and_py
1820

1921

20-
jax = pytest.importorskip("jax")
22+
jax.config.update("jax_platform_name", "cpu")
23+
config.floatX = "float32"
2124

2225

2326
@pytest.mark.parametrize("view", [None, (-1,), slice(-2, None, None)])
@@ -338,7 +341,9 @@ def test_nd_scan_sit_sot(x0_func, A_func):
338341
)
339342

340343
x0_val = (
341-
np.arange(k, dtype=config.floatX) if x0.ndim == 1 else np.diag(np.arange(k))
344+
np.arange(k, dtype=config.floatX)
345+
if x0.ndim == 1
346+
else np.diag(np.arange(k, dtype=config.floatX))
342347
)
343348
A_val = np.eye(k, dtype=config.floatX)
344349

@@ -363,7 +368,7 @@ def test_nd_scan_sit_sot_with_seq():
363368
mode=get_mode("JAX"),
364369
)
365370

366-
x_val = np.tile(np.arange(k, dtype=config.floatX), n_steps).reshape(n_steps, k)
371+
x_val = np.arange(n_steps * k, dtype=config.floatX).reshape(n_steps, k)
367372
A_val = np.eye(k, dtype=config.floatX)
368373

369374
fg = FunctionGraph([x, A], [xs])
@@ -386,7 +391,7 @@ def test_nd_scan_mit_sot():
386391
)
387392

388393
fg = FunctionGraph([x0, A, B], [xs])
389-
x0_val = np.r_[[np.arange(3, dtype=config.floatX).tolist()] * 3]
394+
x0_val = np.arange(9, dtype=config.floatX).reshape(3, 3)
390395
A_val = np.eye(3, dtype=config.floatX)
391396
B_val = np.eye(3, dtype=config.floatX)
392397

0 commit comments

Comments
 (0)