Skip to content

Commit 05d9199

Browse files
committed
Allow overriding modes used in compare_jax_and_py helper
1 parent 29169d4 commit 05d9199

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

tests/link/jax/test_basic.py

+2
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ def compare_jax_and_py(
4040
test_inputs: Iterable,
4141
assert_fn: Optional[Callable] = None,
4242
must_be_device_array: bool = True,
43+
jax_mode=jax_mode,
44+
py_mode=py_mode,
4345
):
4446
"""Function to compare python graph output and jax compiled output for testing equality
4547

0 commit comments

Comments
 (0)