Skip to content

Commit 6762240

Browse files
committed
Fix failing test due to change in JAX behavior
1 parent ce0b503 commit 6762240

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

tests/link/jax/test_tensor_basic.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,15 +176,17 @@ def test_jax_split_not_supported(self):
176176
UserWarning, match="Split node does not have constant split positions."
177177
):
178178
fn = pytensor.function([a], a_splits, mode="JAX")
179-
# It raises an informative ConcretizationTypeError, but there's an AttributeError that surpsasses it
179+
# It raises an informative ConcretizationTypeError, but there's an AttributeError that surasses it
180180
with pytest.raises(AttributeError):
181181
fn(np.zeros((6, 4), dtype=pytensor.config.floatX))
182182

183183
split_axis = iscalar("split_axis")
184184
a_splits = at.split(a, splits_size=[2, 4], n_splits=2, axis=split_axis)
185185
with pytest.warns(UserWarning, match="Split node does not have constant axis."):
186186
fn = pytensor.function([a, split_axis], a_splits, mode="JAX")
187-
with pytest.raises(jax.errors.TracerIntegerConversionError):
187+
# Same as above, an AttributeError surpasses the `TracerIntegerConversionError`
188+
# Both errors are included for backwards compatibility
189+
with pytest.raises((AttributeError, jax.errors.TracerIntegerConversionError)):
188190
fn(np.zeros((6, 6), dtype=pytensor.config.floatX), 0)
189191

190192

0 commit comments

Comments
 (0)