Skip to content

Commit 78939ab

Browse files
brandonwillardricardoV94
authored andcommitted
Make tests compatible with newer version of JAX
1 parent 5affc30 commit 78939ab

File tree

1 file changed

+4
-6
lines changed

1 file changed

+4
-6
lines changed

tests/link/jax/test_basic.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,7 @@ def compare_jax_and_py(
7171

7272
if must_be_device_array:
7373
if isinstance(jax_res, list):
74-
assert all(
75-
isinstance(res, jax.interpreters.xla.DeviceArray) for res in jax_res
76-
)
74+
assert all(isinstance(res, jax.Array) for res in jax_res)
7775
else:
7876
assert isinstance(jax_res, jax.interpreters.xla.DeviceArray)
7977

@@ -146,13 +144,13 @@ def test_shared():
146144
pytensor_jax_fn = function([], a, mode="JAX")
147145
jax_res = pytensor_jax_fn()
148146

149-
assert isinstance(jax_res, jax.interpreters.xla.DeviceArray)
147+
assert isinstance(jax_res, jax.Array)
150148
np.testing.assert_allclose(jax_res, a.get_value())
151149

152150
pytensor_jax_fn = function([], a * 2, mode="JAX")
153151
jax_res = pytensor_jax_fn()
154152

155-
assert isinstance(jax_res, jax.interpreters.xla.DeviceArray)
153+
assert isinstance(jax_res, jax.Array)
156154
np.testing.assert_allclose(jax_res, a.get_value() * 2)
157155

158156
# Changed the shared value and make sure that the JAX-compiled
@@ -161,7 +159,7 @@ def test_shared():
161159
a.set_value(new_a_value)
162160

163161
jax_res = pytensor_jax_fn()
164-
assert isinstance(jax_res, jax.interpreters.xla.DeviceArray)
162+
assert isinstance(jax_res, jax.Array)
165163
np.testing.assert_allclose(jax_res, new_a_value * 2)
166164

167165

0 commit comments

Comments
 (0)