We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
compare_jax_and_py
1 parent 29169d4 commit 05d9199Copy full SHA for 05d9199
tests/link/jax/test_basic.py
@@ -40,6 +40,8 @@ def compare_jax_and_py(
40
test_inputs: Iterable,
41
assert_fn: Optional[Callable] = None,
42
must_be_device_array: bool = True,
43
+ jax_mode=jax_mode,
44
+ py_mode=py_mode,
45
):
46
"""Function to compare python graph output and jax compiled output for testing equality
47
0 commit comments