Skip to content

Commit 2dc912d

Browse files
committed
Use JAX mode when testing jax dispatching
1 parent 63f8d6e commit 2dc912d

File tree

2 files changed

+9
-12
lines changed

2 files changed

+9
-12
lines changed

tests/link/jax/test_basic.py

+4-9
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,13 @@
55
import pytest
66

77
from pytensor.compile.function import function
8-
from pytensor.compile.mode import Mode
8+
from pytensor.compile.mode import get_mode
99
from pytensor.compile.sharedvalue import SharedVariable, shared
1010
from pytensor.configdefaults import config
1111
from pytensor.graph.basic import Apply
1212
from pytensor.graph.fg import FunctionGraph
1313
from pytensor.graph.op import Op, get_test_value
14-
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
1514
from pytensor.ifelse import ifelse
16-
from pytensor.link.jax import JAXLinker
1715
from pytensor.raise_op import assert_op
1816
from pytensor.tensor.type import dscalar, scalar, vector
1917

@@ -27,12 +25,9 @@ def set_pytensor_flags():
2725
jax = pytest.importorskip("jax")
2826

2927

30-
jax_mode = Mode(
31-
JAXLinker(), RewriteDatabaseQuery(include=["jax"], exclude=["cxx_only", "BlasOpt"])
32-
)
33-
py_mode = Mode(
34-
"py", RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
35-
)
28+
# We assume that the JAX mode includes all the rewrites needed to transpile JAX graphs
29+
jax_mode = get_mode("JAX")
30+
py_mode = get_mode("FAST_COMPILE")
3631

3732

3833
def compare_jax_and_py(

tests/link/jax/test_subtensor.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -71,13 +71,15 @@ def test_jax_Subtensor_dynamic():
7171

7272
def test_jax_Subtensor_boolean_mask():
7373
"""JAX does not support resizing arrays with boolean masks."""
74-
x_at = at.arange(-5, 5)
74+
x_at = at.vector("x", dtype="float64")
7575
out_at = x_at[x_at < 0]
7676
assert isinstance(out_at.owner.op, at_subtensor.AdvancedSubtensor)
7777

78+
out_fg = FunctionGraph([x_at], [out_at])
79+
80+
x_at_test = np.arange(-5, 5)
7881
with pytest.raises(NotImplementedError, match="resizing arrays with boolean"):
79-
out_fg = FunctionGraph([], [out_at])
80-
compare_jax_and_py(out_fg, [])
82+
compare_jax_and_py(out_fg, [x_at_test])
8183

8284

8385
def test_jax_Subtensor_boolean_mask_reexpressible():

0 commit comments

Comments
 (0)