Skip to content

Commit cffd161

Browse files
committed
Return PyTensor JAX function in compare_jax_and_py helper
1 parent 05d9199 commit cffd161

File tree

3 files changed

+3
-3
lines changed

3 files changed

+3
-3
lines changed

tests/link/jax/test_basic.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def compare_jax_and_py(
8989
else:
9090
assert_fn(jax_res, py_res)
9191

92-
return jax_res
92+
return pytensor_jax_fn, jax_res
9393

9494

9595
def test_jax_FunctionGraph_once():

tests/link/jax/test_slinalg.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def test_jax_basic():
3333
np.tile(np.arange(10), (10, 1)).astype(config.floatX),
3434
np.tile(np.arange(10, 20), (10, 1)).astype(config.floatX),
3535
]
36-
(jax_res,) = compare_jax_and_py(out_fg, test_input_vals)
36+
_, [jax_res] = compare_jax_and_py(out_fg, test_input_vals)
3737

3838
# Confirm that the `Subtensor` slice operations are correct
3939
assert jax_res.shape == (5, 3)

tests/link/jax/test_tensor_basic.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def test_jax_Alloc():
1515
x = at.alloc(0.0, 2, 3)
1616
x_fg = FunctionGraph([], [x])
1717

18-
(jax_res,) = compare_jax_and_py(x_fg, [])
18+
_, [jax_res] = compare_jax_and_py(x_fg, [])
1919

2020
assert jax_res.shape == (2, 3)
2121

0 commit comments

Comments
 (0)