@@ -71,9 +71,7 @@ def compare_jax_and_py(
71
71
72
72
if must_be_device_array :
73
73
if isinstance (jax_res , list ):
74
- assert all (
75
- isinstance (res , jax .interpreters .xla .DeviceArray ) for res in jax_res
76
- )
74
+ assert all (isinstance (res , jax .Array ) for res in jax_res )
77
75
else :
78
76
assert isinstance (jax_res , jax .interpreters .xla .DeviceArray )
79
77
@@ -146,13 +144,13 @@ def test_shared():
146
144
pytensor_jax_fn = function ([], a , mode = "JAX" )
147
145
jax_res = pytensor_jax_fn ()
148
146
149
- assert isinstance (jax_res , jax .interpreters . xla . DeviceArray )
147
+ assert isinstance (jax_res , jax .Array )
150
148
np .testing .assert_allclose (jax_res , a .get_value ())
151
149
152
150
pytensor_jax_fn = function ([], a * 2 , mode = "JAX" )
153
151
jax_res = pytensor_jax_fn ()
154
152
155
- assert isinstance (jax_res , jax .interpreters . xla . DeviceArray )
153
+ assert isinstance (jax_res , jax .Array )
156
154
np .testing .assert_allclose (jax_res , a .get_value () * 2 )
157
155
158
156
# Changed the shared value and make sure that the JAX-compiled
@@ -161,7 +159,7 @@ def test_shared():
161
159
a .set_value (new_a_value )
162
160
163
161
jax_res = pytensor_jax_fn ()
164
- assert isinstance (jax_res , jax .interpreters . xla . DeviceArray )
162
+ assert isinstance (jax_res , jax .Array )
165
163
np .testing .assert_allclose (jax_res , new_a_value * 2 )
166
164
167
165
0 commit comments